Correct the batch norm forward training algorithm

parent 40cef28c
import torch
from jpeg_codec import S_i
class BatchNorm(torch.nn.modules.Module):
    """Batch normalization over block-wise transform coefficients.

    The layer is built from an existing batch-norm layer ``bn`` and copies its
    running statistics, affine parameters, ``momentum`` and ``eps``.

    ``forward`` expects a 5-D tensor indexed as
    ``(batch, channel, block_row, block_col, coefficient)``. Coefficient 0 of
    every block is treated as that block's mean, so the mean shift and the
    ``beta`` offset are applied to coefficient 0 only, while the scale is
    applied to every coefficient.
    NOTE(review): presumably these are JPEG DCT blocks with 64 coefficients
    and ``S_i()`` returns a 64x64 matrix -- confirm against ``jpeg_codec``.
    """

    def __init__(self, bn):
        """Copy state from ``bn``.

        Args:
            bn: a batch-norm layer exposing ``running_mean``, ``running_var``,
                ``num_batches_tracked``, ``weight``, ``bias``, ``momentum``
                and ``eps``.
        """
        super(BatchNorm, self).__init__()

        # Clone everything so this layer never shares (and silently mutates)
        # the state of the layer it was converted from.
        self.register_buffer('running_mean', bn.running_mean.clone())
        self.register_buffer('running_var', bn.running_var.clone())
        self.register_buffer('num_batches_tracked', bn.num_batches_tracked.clone())
        self.register_buffer('S_i', S_i())

        self.gamma = torch.nn.Parameter(bn.weight.detach().clone())
        self.beta = torch.nn.Parameter(bn.bias.detach().clone())
        self.momentum = bn.momentum
        self.eps = bn.eps

    def _batch_statistics(self, input):
        """Return ``(batch_mean, biased_batch_var, samples_per_channel)``."""
        channels = input.shape[1]
        channelwise = input.permute(1, 0, 2, 3, 4)

        # Coefficient 0 of each block is used as the block mean.
        block_means = channelwise[:, :, :, :, 0].reshape(channels, -1)
        batch_mean = torch.mean(block_means, 1)

        # Transform the coefficients with S_i and drop component 0, leaving
        # zero-mean blocks whose mean square is the per-block variance.
        dequantized = torch.einsum('mtxyk,gk->mtxyg', [channelwise, self.S_i])
        dequantized[:, :, :, :, 0] = 0  # einsum output is a fresh tensor
        block_variances = torch.mean(dequantized ** 2, 4).reshape(channels, -1)

        # Law of total variance: E[block_var + block_mean^2] - batch_mean^2.
        batch_var = torch.mean(block_variances + block_means ** 2, 1) - batch_mean ** 2

        samples_per_channel = dequantized.numel() // channels
        return batch_mean, batch_var, samples_per_channel

    def _update_running_stats(self, batch_mean, batch_var, samples_per_channel):
        """Fold the current batch statistics into the running buffers."""
        # Running statistics must not keep the autograd graph alive.
        with torch.no_grad():
            self.num_batches_tracked += 1

            if self.momentum is None:
                # momentum=None means a cumulative moving average in PyTorch.
                factor = 1.0 / float(self.num_batches_tracked)
            else:
                factor = self.momentum

            # PyTorch stores the unbiased variance in running_var while
            # normalizing the batch with the biased one.
            if samples_per_channel > 1:
                bessel = samples_per_channel / (samples_per_channel - 1)
            else:
                bessel = 1.0

            self.running_mean.mul_(1 - factor).add_(batch_mean.detach() * factor)
            self.running_var.mul_(1 - factor).add_(batch_var.detach() * bessel * factor)

    def forward(self, input):
        """Normalize ``input`` and return a new tensor; ``input`` is not modified.

        In training mode the batch statistics are used (and the running
        statistics are updated); in eval mode the running statistics are used.
        """
        if self.training:
            mean, var, samples_per_channel = self._batch_statistics(input)
            self._update_running_stats(mean, var, samples_per_channel)
        else:
            mean, var = self.running_mean, self.running_var

        scale = (self.gamma / torch.sqrt(var + self.eps)).view(1, -1, 1, 1, 1)
        shift = mean.view(1, -1, 1, 1)
        offset = self.beta.view(1, -1, 1, 1)

        # Work on a copy so the caller's tensor is left untouched.
        output = input.clone()
        output[:, :, :, :, 0] = output[:, :, :, :, 0] - shift
        output = output * scale
        output[:, :, :, :, 0] = output[:, :, :, :, 0] + offset

        return output
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment