Batch norm training mode first attempt

parent 79824871
......@@ -5,19 +5,45 @@ class BatchNorm(torch.nn.modules.Module):
def __init__(self, bn):
super(BatchNorm, self).__init__()
self.mean = bn.running_mean
self.var = bn.running_var
self.register_buffer('running_mean', bn.running_mean)
self.register_buffer('running_var', bn.running_var)
self.register_buffer('num_batches_tracked', bn.num_batches_tracked)
self.gamma = bn.weight
self.beta = bn.bias
gamma_final = (self.gamma / torch.sqrt(self.var)).view(1, self.gamma.shape[0], 1, 1, 1)
self.register_buffer('gamma_final', gamma_final)
beta_final = (self.beta - (self.gamma * self.mean) / torch.sqrt(self.var)).view(1, self.beta.shape[0], 1, 1)
self.register_buffer('beta_final', beta_final)
self.register_parameter('gamma', self.gamma)
self.register_parameter('beta', self.beta)
def forward(self, input):
input = input * self.gamma_final
input[:, :, :, :, 0] = input[:, :, :, :, 0] + self.beta_final
if self.training:
# Compute the batch mean for each channel
channels = input.shape[1]
block_means = input[:, :, :, :, 0].permute(1, 0, 2, 3).contiguous().view(channels, -1) # channels x everything else
batch_mean = torch.mean(block_means, 1)
# Compute the batch variance for each channel
input[:, :, :, :, 0] = 0 # zero mean
batch_var = torch.mean(input.permute(1, 0, 2, 3, 4).contiguous().view(channels, -1)**2, 1)
# Apply parameters
bv = batch_var.view(1, batch_var.shape[0], 1, 1, 1)
b = self.beta.view(1, self.beta.shape[0], 1, 1)
g = self.gamma.view(1, self.gamma.shape[0], 1, 1, 1)
input = (input / torch.sqrt(bv)) * g
input[:, :, :, :, 0] = input[:, :, :, :, 0] + b
# Update running stats
self.running_mean = (self.running_mean * self.num_batches_tracked.float() + batch_mean) / (self.num_batches_tracked.float() + 1)
self.running_var = (self.running_var * self.num_batches_tracked.float() + batch_var) / (self.num_batches_tracked.float() + 1)
self.num_batches_tracked += 1
else:
gamma_final = (self.gamma / torch.sqrt(self.running_var)).view(1, self.gamma.shape[0], 1, 1, 1)
beta_final = (self.beta - (self.gamma * self.running_mean) / torch.sqrt(self.running_var)).view(1, self.beta.shape[0], 1, 1)
input = input * gamma_final
input[:, :, :, :, 0] = input[:, :, :, :, 0] + beta_final
return input
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment