...
 
Commits (3)
import torch
from jpeg_codec import S_i
class BatchNorm(torch.nn.modules.Module):
    """Batch normalization applied directly to blockwise coefficient tensors.

    The layer is built from an existing batch-norm layer ``bn`` and copies its
    running statistics, affine parameters, ``momentum`` and ``eps``.

    Layout assumption (inferred from the indexing below, confirm against the
    caller): ``input`` is 5-D with channels on axis 1 and the coefficients of
    one block on the last axis.  Coefficient 0 of every block is treated as
    the block mean; the remaining coefficients, once multiplied by ``S_i``,
    are treated as zero-mean deviations.

    ``bn.momentum`` must be a number; the cumulative-average mode of PyTorch
    (``momentum=None``) is not supported by the running-stat update.
    """

    def __init__(self, bn):
        """Copy statistics and parameters out of the batch-norm layer ``bn``."""
        super(BatchNorm, self).__init__()

        # Clone everything so this layer never aliases (and later overwrites)
        # the tensors owned by the layer it was built from.
        self.register_buffer('running_mean', bn.running_mean.clone())
        self.register_buffer('running_var', bn.running_var.clone())
        self.register_buffer('S_i', S_i())

        self.momentum = bn.momentum
        self.eps = bn.eps

        self.gamma = torch.nn.Parameter(bn.weight.clone())
        self.beta = torch.nn.Parameter(bn.bias.clone())

    def _batch_statistics(self, input):
        """Return the per-channel (mean, biased variance) of ``input``."""
        channels = input.shape[1]
        # Channels first so each channel's values can be flattened together.
        input_channelwise = input.permute(1, 0, 2, 3, 4).clone()

        # Compute the batch mean for each channel
        block_means = input_channelwise[:, :, :, :, 0].contiguous().reshape(channels, -1)
        batch_mean = torch.mean(block_means, 1)

        # Compute the batch variance for each channel
        input_dequantized = torch.einsum('mtxyk,gk->mtxyg', [input_channelwise, self.S_i])
        input_dequantized[:, :, :, :, 0] = 0  # zero mean
        block_variances = torch.mean(input_dequantized**2, 4).reshape(channels, -1)
        # Combine per-block moments: E[x^2] - E[x]^2 over the whole batch.
        batch_var = torch.mean(block_variances + block_means**2, 1) - batch_mean**2

        return batch_mean, batch_var

    def _update_running_stats(self, batch_mean, batch_var, samples_per_channel):
        """Blend the batch statistics into the running buffers."""
        # no_grad keeps the buffers free of autograd history; otherwise every
        # update would chain onto the graphs of all previous iterations.
        with torch.no_grad():
            # Apply bessel correction to match pytorch: only the running
            # variance is unbiased, normalization uses the biased estimate.
            if samples_per_channel > 1:
                unbiased_var = batch_var * (samples_per_channel / (samples_per_channel - 1))
            else:
                unbiased_var = batch_var

            self.running_mean = self.running_mean * (1 - self.momentum) + batch_mean * self.momentum
            self.running_var = self.running_var * (1 - self.momentum) + unbiased_var * self.momentum

    def forward(self, input):
        """Normalize ``input`` per channel and apply the affine parameters.

        In training mode the statistics come from the current batch and the
        running buffers are updated; in eval mode the running buffers are
        used.  Returns a new tensor of the same shape; ``input`` itself is
        left unmodified.
        """
        if self.training:
            mean, var = self._batch_statistics(input)
            self._update_running_stats(mean, var, input.numel() // input.shape[1])
        else:
            mean, var = self.running_mean, self.running_var

        # Apply parameters
        invstd = 1. / torch.sqrt(var + self.eps).view(1, -1, 1, 1, 1)
        mean = mean.view(1, -1, 1, 1)
        g = self.gamma.view(1, -1, 1, 1, 1)
        b = self.beta.view(1, -1, 1, 1)

        # Work on a copy: the slice assignments below must not write into
        # the caller's tensor.
        output = input.clone()
        # Only coefficient 0 carries the mean, so the shift and the bias
        # touch that coefficient alone; the scale applies to all of them.
        output[:, :, :, :, 0] = output[:, :, :, :, 0] - mean
        output = output * invstd
        output = output * g
        output[:, :, :, :, 0] = output[:, :, :, :, 0] + b

        return output
......@@ -8,10 +8,8 @@ class Conv2d(torch.nn.modules.Module):
super(Conv2d, self).__init__()
self.stride = conv_spatial.stride
self.weight = conv_spatial.weight
self.padding = conv_spatial.padding
self.register_parameter('weight', self.weight)
self.weight = torch.nn.Parameter(conv_spatial.weight.clone())
self.register_buffer('J', J[0])
self.register_buffer('J_i', J[1])
......
......@@ -13,42 +13,4 @@ set key autotitle columnhead
unset colorbox
set key tmargin horizontal
set key above vertical maxrows 3
# method colors
apx_color = 0.85
asm_color = 0.35
# dataset point types
mnist_point = 7
cifar10_point = 9
cifar100_point = 5
block_point = 2
# spatial lines
set style line 1 linewidth 8 dashtype 2 linetype rgb "black"
spatial = 1
# mnist lines
set style line 2 linewidth 8 pointsize 4 pointtype mnist_point palette frac asm_color
set style line 3 linewidth 8 pointsize 4 pointtype mnist_point palette frac apx_color
asm_mnist = 2
apx_mnist = 3
# cifar10 lines
set style line 4 linewidth 8 pointsize 5 pointtype cifar10_point palette frac asm_color
set style line 5 linewidth 8 pointsize 5 pointtype cifar10_point palette frac apx_color
asm_cifar10 = 4
apx_cifar10 = 5
# cifar100 lines
set style line 6 linewidth 8 pointsize 4 pointtype cifar100_point palette frac asm_color
set style line 7 linewidth 8 pointsize 4 pointtype cifar100_point palette frac apx_color
asm_cifar100 = 6
apx_cifar100 = 7
# block lines
set style line 8 linewidth 8 pointsize 4 pointtype block_point palette frac asm_color
set style line 9 linewidth 8 pointsize 4 pointtype block_point palette frac apx_color
asm_block = 8
apx_block = 9
\ No newline at end of file
set key above vertical maxrows 3
\ No newline at end of file
#!/usr/bin/gnuplot -c
load "common.gp"
load "relu_styles.gp"
set ylabel 'Average Accuracy (%)'
......
#!/usr/bin/gnuplot -c
load "common.gp"
load "relu_styles.gp"
set ylabel 'Average RMSE'
......
# Line-style table: defines numbered gnuplot line styles 1-9 and gives each
# index a variable name, so a plot command can select a style by name
# instead of by bare number.
#
# method colors: fractional positions in the active palette, consumed below
# through `palette frac`.
# NOTE(review): `apx` / `asm` are not defined in this file -- presumably the
# two methods being compared; confirm against the plotting scripts.
apx_color = 0.85
asm_color = 0.35
# dataset point types: gnuplot point-type numbers, one per dataset
mnist_point = 7
cifar10_point = 9
cifar100_point = 5
block_point = 2
# spatial lines: a single dashed black style with no point marker
set style line 1 linewidth 8 dashtype 2 linetype rgb "black"
spatial = 1
# mnist lines: one style per method color, both with the mnist marker
set style line 2 linewidth 8 pointsize 4 pointtype mnist_point palette frac asm_color
set style line 3 linewidth 8 pointsize 4 pointtype mnist_point palette frac apx_color
asm_mnist = 2
apx_mnist = 3
# cifar10 lines
# NOTE(review): pointsize is 5 here but 4 for every other dataset --
# presumably to compensate for the glyph's apparent size; confirm.
set style line 4 linewidth 8 pointsize 5 pointtype cifar10_point palette frac asm_color
set style line 5 linewidth 8 pointsize 5 pointtype cifar10_point palette frac apx_color
asm_cifar10 = 4
apx_cifar10 = 5
# cifar100 lines
set style line 6 linewidth 8 pointsize 4 pointtype cifar100_point palette frac asm_color
set style line 7 linewidth 8 pointsize 4 pointtype cifar100_point palette frac apx_color
asm_cifar100 = 6
apx_cifar100 = 7
# block lines
set style line 8 linewidth 8 pointsize 4 pointtype block_point palette frac asm_color
set style line 9 linewidth 8 pointsize 4 pointtype block_point palette frac apx_color
asm_block = 8
apx_block = 9
\ No newline at end of file