Verified commit 5b29b727, authored by Max Ehrlich
Browse files

Correct the batch norm forward training algorithm

parent 40cef28c
import torch import torch
from jpeg_codec import S_i
class BatchNorm(torch.nn.modules.Module):
    """Batch normalization applied to 5-D block-coefficient tensors.

    The layer is initialised from an existing batch-norm module ``bn`` and
    copies its running statistics, affine parameters, ``momentum`` and
    ``eps``.

    ``forward`` indexes its input with five dimensions and treats
    dimension 1 as the channel axis.  Coefficient 0 along the last axis is
    used as the per-block mean: it is the only coefficient that is shifted by
    the mean and by ``beta``; every coefficient is scaled.
    NOTE(review): the input is presumably (batch, channels, blocks_y,
    blocks_x, 64) JPEG-domain coefficients and ``S_i()`` a 64x64 transform
    applied to them -- confirm against ``jpeg_codec``.
    """

    def __init__(self, bn):
        """Copy statistics and affine parameters out of ``bn``.

        Args:
            bn: module exposing ``running_mean``, ``running_var``, ``weight``,
                ``bias``, ``momentum`` and ``eps``.  ``momentum`` is used
                arithmetically in ``forward``, so it is assumed to be a
                number (not ``None``).
        """
        super().__init__()
        # Clone so that updating this layer never touches ``bn``'s tensors.
        self.register_buffer('running_mean', bn.running_mean.clone())
        self.register_buffer('running_var', bn.running_var.clone())
        self.register_buffer('S_i', S_i())
        self.momentum = bn.momentum
        self.eps = bn.eps
        self.gamma = torch.nn.Parameter(bn.weight.detach().clone())
        self.beta = torch.nn.Parameter(bn.bias.detach().clone())

    def forward(self, input):
        """Normalize ``input`` per channel and apply ``gamma`` / ``beta``.

        In training mode the batch statistics are computed from ``input`` and
        the running statistics are updated in place; in eval mode the running
        statistics are used.  ``input`` itself is never modified.

        Args:
            input: 5-D tensor with channels on dimension 1.

        Returns:
            A new tensor with the same shape as ``input``.
        """
        if self.training:
            channels = input.shape[1]
            channelwise = input.permute(1, 0, 2, 3, 4)

            # Per-channel mean: average of coefficient 0 over every block.
            block_means = channelwise[..., 0].reshape(channels, -1)
            batch_mean = block_means.mean(dim=1)

            # Per-block variance: mean square of the transformed coefficients
            # with coefficient 0 (the block mean) excluded.  Dropping index 0
            # from the sum is equivalent to zeroing it, without an in-place
            # write.
            transformed = torch.einsum('mtxyk,gk->mtxyg', channelwise, self.S_i)
            block_variances = (
                (transformed[..., 1:] ** 2).sum(dim=4) / transformed.shape[4]
            ).reshape(channels, -1)

            # Var[x] = E[x^2] - E[x]^2, with E[x^2] assembled per block as
            # (block variance + block mean^2).  Clamp guards against a tiny
            # negative result from floating-point cancellation.
            batch_var = (
                (block_variances + block_means ** 2).mean(dim=1) - batch_mean ** 2
            ).clamp(min=0)

            # Update running stats outside of autograd so the buffers never
            # hold on to the graph of the current batch.
            with torch.no_grad():
                # Bessel's correction is applied to the running estimate only;
                # the batch itself is normalized with the biased variance,
                # which is what torch.nn.BatchNorm2d does.
                count = input.numel() // channels
                unbiased_var = batch_var * (count / max(count - 1, 1))
                self.running_mean.mul_(1 - self.momentum).add_(
                    batch_mean, alpha=self.momentum)
                self.running_var.mul_(1 - self.momentum).add_(
                    unbiased_var, alpha=self.momentum)

            mean = batch_mean
            var = batch_var
        else:
            mean = self.running_mean
            var = self.running_var

        # ((x - mean) * invstd) * gamma + beta, folded into one scale applied
        # to every coefficient and one shift applied to coefficient 0 only.
        scale = self.gamma / torch.sqrt(var + self.eps)
        shift = self.beta - mean * scale

        output = input * scale.view(1, -1, 1, 1, 1)  # fresh tensor
        output[..., 0] += shift.view(1, -1, 1, 1)
        return output
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment