@@ -3,7 +3,6 @@ import torch
import torch.optim as optim
import argparse
import data
import models
device = torch.device('cuda')
...
import models
import torch
import torch.optim as optim
import argparse
import data
device = torch.device('cuda')
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', choices=data.spatial_dataset_map.keys(), help='Dataset to use')
parser.add_argument('--data', help='Root folder for the dataset')
args = parser.parse_args()
dataset_info = data.dataset_info[args.dataset]
# Build the spatial-domain ResNet, then wrap it in its exact JPEG-domain counterpart
spatial_model = models.SpatialResNet(dataset_info['channels'], dataset_info['classes'])
jpeg_model = models.JpegResNetExact(spatial_model).to(device)
train_set, test_set = data.jpeg_dataset_map[args.dataset](128, args.data)
optimizer = optim.Adam(jpeg_model.parameters())
for e in range(5):
    models.train(jpeg_model, device, train_set, optimizer, e)
models.test(jpeg_model, device, test_set)
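The JPEG-domain model is constructed around the spatial one rather than trained from separate weights; the wrapped layers in the hunks below register the spatial parameters directly. A small, hypothetical sanity check of that sharing (the 'mnist' dataset key is an assumption, not taken from the repository):

import data
import models

info = data.dataset_info['mnist']  # assumed dataset key, for illustration only
spatial = models.SpatialResNet(info['channels'], info['classes'])
jpeg = models.JpegResNetExact(spatial)
# Count how many JPEG-model parameters are the very same tensors as spatial-model
# parameters (shared storage), rather than independent copies.
spatial_ptrs = {p.data_ptr() for p in spatial.parameters()}
shared = sum(p.data_ptr() in spatial_ptrs for p in jpeg.parameters())
total = sum(1 for _ in jpeg.parameters())
print(f'{shared} of {total} JPEG-model parameters are shared with the spatial model')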
@@ -5,19 +5,45 @@ class BatchNorm(torch.nn.modules.Module):
    def __init__(self, bn):
        super(BatchNorm, self).__init__()
        # Running statistics taken from the wrapped BatchNorm
        self.mean = bn.running_mean
        self.var = bn.running_var
        # Registered as buffers so they move with the module and are saved in its state dict
        self.register_buffer('running_mean', bn.running_mean)
        self.register_buffer('running_var', bn.running_var)
        self.register_buffer('num_batches_tracked', bn.num_batches_tracked)
        # Affine parameters shared with the wrapped BatchNorm
        self.gamma = bn.weight
        self.beta = bn.bias
        # Fold the affine transform and running statistics into a single per-channel scale
        # (applied to every DCT coefficient) and shift (applied to the DC coefficient only)
        gamma_final = (self.gamma / torch.sqrt(self.var)).view(1, self.gamma.shape[0], 1, 1, 1)
        self.register_buffer('gamma_final', gamma_final)
        beta_final = (self.beta - (self.gamma * self.mean) / torch.sqrt(self.var)).view(1, self.beta.shape[0], 1, 1)
        self.register_buffer('beta_final', beta_final)
        self.register_parameter('gamma', self.gamma)
        self.register_parameter('beta', self.beta)
    def forward(self, input):
        input = input * self.gamma_final
        input[:, :, :, :, 0] = input[:, :, :, :, 0] + self.beta_final
        if self.training:
            # Compute the batch mean for each channel from the DC coefficients
            channels = input.shape[1]
            block_means = input[:, :, :, :, 0].permute(1, 0, 2, 3).contiguous().view(channels, -1)  # channels x everything else
            batch_mean = torch.mean(block_means, 1)
            # Compute the batch variance for each channel
            input[:, :, :, :, 0] = 0  # zero the DC coefficient so each block is zero-mean
            batch_var = torch.mean(input.permute(1, 0, 2, 3, 4).contiguous().view(channels, -1)**2, 1)
            # Apply the affine parameters
            bv = batch_var.view(1, batch_var.shape[0], 1, 1, 1)
            b = self.beta.view(1, self.beta.shape[0], 1, 1)
            g = self.gamma.view(1, self.gamma.shape[0], 1, 1, 1)
            input = (input / torch.sqrt(bv)) * g
            input[:, :, :, :, 0] = input[:, :, :, :, 0] + b
            # Update the running statistics (cumulative moving average)
            self.running_mean = (self.running_mean * self.num_batches_tracked.float() + batch_mean) / (self.num_batches_tracked.float() + 1)
            self.running_var = (self.running_var * self.num_batches_tracked.float() + batch_var) / (self.num_batches_tracked.float() + 1)
            self.num_batches_tracked += 1
        else:
            # Fold the affine transform and running statistics into a single scale and shift
            gamma_final = (self.gamma / torch.sqrt(self.running_var)).view(1, self.gamma.shape[0], 1, 1, 1)
            beta_final = (self.beta - (self.gamma * self.running_mean) / torch.sqrt(self.running_var)).view(1, self.beta.shape[0], 1, 1)
            input = input * gamma_final
            input[:, :, :, :, 0] = input[:, :, :, :, 0] + beta_final
        return input
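The eval branch relies on the usual batch-norm folding identity, with the shift applied only to the DC coefficient because adding a constant to a block changes only its DC term. In the training branch, zeroing the DC coefficient likewise removes each block's mean before the squared coefficients are averaged into the per-channel variance. A minimal standalone sketch, independent of the classes above, that checks the folding algebra numerically:

import torch

# gamma * (x - mean) / sqrt(var) + beta  ==  (gamma / sqrt(var)) * x + (beta - gamma * mean / sqrt(var))
torch.manual_seed(0)
x, mean, var = torch.randn(8), torch.randn(8), torch.rand(8) + 0.1
gamma, beta = torch.randn(8), torch.randn(8)
lhs = gamma * (x - mean) / torch.sqrt(var) + beta
rhs = (gamma / torch.sqrt(var)) * x + (beta - gamma * mean / torch.sqrt(var))
print(torch.allclose(lhs, rhs, atol=1e-6))  # True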
@@ -11,6 +11,8 @@ class Conv2d(torch.nn.modules.Module):
        self.weight = conv_spatial.weight
        self.padding = conv_spatial.padding
        self.register_parameter('weight', self.weight)
        # Register J and J_i as buffers so they follow the module across devices
        # and are included in its state dict
        self.register_buffer('J', J[0])
        self.register_buffer('J_i', J[1])
...
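A registered buffer, unlike a plain tensor attribute, is moved by .to(device) and serialized with the module's state dict, which is presumably why J and J_i are registered here rather than stored as ordinary attributes. A small illustration with a hypothetical module (the Demo name and tensors are made up, not from the repository):

import torch

class Demo(torch.nn.Module):
    def __init__(self):
        super(Demo, self).__init__()
        self.register_buffer('J', torch.eye(3))  # tracked: moved by .to() and saved in state_dict
        self.J_plain = torch.eye(3)              # plain attribute: not moved by .to(), not saved

m = Demo()
print('J' in m.state_dict())        # True
print('J_plain' in m.state_dict())  # False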
@@ -9,7 +9,7 @@ def train(model, device, train_loader, optimizer, epoch):
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward(retain_graph=True)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
...
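On the retain_graph change: retain_graph=True is only needed when the same graph is backpropagated more than once, so a single backward pass per batch works with plain backward() and lets PyTorch free the intermediate buffers. A standalone illustration of the difference:

import torch

x = torch.randn(3, requires_grad=True)
loss = (x * x).sum()
loss.backward()       # frees the graph's saved intermediate values
try:
    loss.backward()   # a second pass would have required retain_graph=True on the first call
except RuntimeError as err:
    print('second backward failed as expected:', err)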