...
 
Commits (12)
from .datasets import *
\ No newline at end of file
import torch
from torchvision import datasets, transforms
from jpeg_dataset import JpegDataset
from .jpeg_dataset import JpegDataset
def spatial_data(batch_size, root, name, dataset, transform, shuffle_train):
......@@ -20,7 +20,8 @@ def mnist_spatial(batch_size, root, shuffle_train=True):
transforms.Normalize((0.1307,), (0.3081,))
])
return spatial_data(batch_size=batch_size, root=root, name='MNIST-spatial', dataset=datasets.MNIST, transform=transform, shuffle_train=shuffle_train)
directory = '{}/{}'.format(root, 'MNIST')
return spatial_data(batch_size=batch_size, root=directory, name='MNIST-spatial', dataset=datasets.MNIST, transform=transform, shuffle_train=shuffle_train)
def cifar10_spatial(batch_size, root, shuffle_train=True):
......@@ -29,7 +30,8 @@ def cifar10_spatial(batch_size, root, shuffle_train=True):
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
])
return spatial_data(batch_size=batch_size, root=root, name='CIFAR10-spatial', dataset=datasets.CIFAR10, transform=transform, shuffle_train=shuffle_train)
directory = '{}/{}'.format(root, 'CIFAR10')
return spatial_data(batch_size=batch_size, root=directory, name='CIFAR10-spatial', dataset=datasets.CIFAR10, transform=transform, shuffle_train=shuffle_train)
def cifar100_spatial(batch_size, root, shuffle_train=True):
......@@ -38,7 +40,8 @@ def cifar100_spatial(batch_size, root, shuffle_train=True):
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])
return spatial_data(batch_size=batch_size, root=root, name='CIFAR100-spatial', dataset=datasets.CIFAR100, transform=transform, shuffle_train=shuffle_train)
directory = '{}/{}'.format(root, 'CIFAR100')
return spatial_data(batch_size=batch_size, root=directory, name='CIFAR100-spatial', dataset=datasets.CIFAR100, transform=transform, shuffle_train=shuffle_train)
def jpeg_data(batch_size, directory, shuffle_train):
......@@ -64,3 +67,31 @@ def cifar10_jpeg(batch_size, root, shuffle_train=True):
def cifar100_jpeg(batch_size, root, shuffle_train=True):
    """Build train/test loaders for the JPEG-domain CIFAR100 dataset.

    Looks for the converted dataset under <root>/CIFAR100 and delegates
    to jpeg_data for the actual loading.
    """
    dataset_dir = '/'.join((root, 'CIFAR100'))
    return jpeg_data(batch_size, dataset_dir, shuffle_train)
# Dataset name -> factory returning (train, test) loaders over the
# original spatial (pixel-domain) images.
spatial_dataset_map = {
'MNIST': mnist_spatial,
'CIFAR10': cifar10_spatial,
'CIFAR100': cifar100_spatial
}
# Dataset name -> factory returning (train, test) loaders over the
# converted JPEG-domain data (loaded via jpeg_data / JpegDataset).
jpeg_dataset_map = {
'MNIST': mnist_jpeg,
'CIFAR10': cifar10_jpeg,
'CIFAR100': cifar100_jpeg
}
# Per-dataset model-construction metadata: 'channels' is the number of
# input image channels, 'classes' the number of output classes (consumed
# as SpatialResNet(info['channels'], info['classes']) by the train script).
dataset_info = {
'MNIST': {
'channels': 1,
'classes': 10
},
'CIFAR10': {
'channels': 3,
'classes': 10
},
'CIFAR100': {
'channels': 3,
'classes': 100,
}
}
import torch
from jpeg_readwrite import load_jpeg
from .jpeg_readwrite import load_jpeg
class JpegDataset(torch.utils.data.Dataset):
......@@ -10,4 +10,4 @@ class JpegDataset(torch.utils.data.Dataset):
return self.data.shape[0]
def __getitem__(self, idx):
return self.dataset[idx], self.labels[idx]
return self.data[idx], self.labels[idx]
......@@ -3,21 +3,15 @@ from .datasets import *
from .jpeg_readwrite import *
from jpeg_codec import encode
dataset_map = {
'MNIST': mnist_spatial,
'CIFAR10': cifar10_spatial,
'CIFAR100': cifar100_spatial
}
parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, help='Number of images to convert at a time', default=60000)
parser.add_argument('dataset', choices=dataset_map.keys(), help='Dataset to convert')
parser.add_argument('dataset', choices=spatial_dataset_map.keys(), help='Dataset to convert')
parser.add_argument('directory', help='directory to load and store datasets')
args = parser.parse_args()
print(args)
train_data, test_data = dataset_map[args.dataset](args.batch_size, root=args.directory, shuffle_train=False)
train_data, test_data = spatial_dataset_map[args.dataset](args.batch_size, root=args.directory, shuffle_train=False)
device = torch.device('cuda')
......
......@@ -13,10 +13,9 @@ def write_jpeg(data, labels, directory, fname):
def load_jpeg(directory, fname):
    """Load a saved data/labels archive from <directory>/<fname>.

    Parameters
    ----------
    directory : str
        Directory holding the archive (as written by write_jpeg).
    fname : str
        Archive file name, e.g. 'train.npz'.

    Returns
    -------
    (torch.Tensor, torch.LongTensor)
        The data tensor (float) and the integer label tensor. Labels are
        LongTensor because class indices are expected downstream.
    """
    path = '{}/{}'.format(directory, fname)
    # Binary mode is required: np.load consumes a binary .npy/.npz stream.
    with open(path, 'rb') as f:
        dl = np.load(f)
        data = dl['data']
        labels = dl['labels']
    return torch.Tensor(data), torch.LongTensor(labels)
import torch
import torch.optim as optim
import argparse

import data
import models

# Entry point: trains args.models independent spatial-domain ResNets,
# converts each into an exact JPEG-domain model, and reports the average
# accuracy of both along with the deviation introduced by the conversion.
# (Diff residue had `import models` twice and kept both the old bare
# `models.test(...)` call and the new `acc = models.test(...)` line, which
# would have run a redundant full evaluation pass; both fixed here.)

device = torch.device('cuda')

parser = argparse.ArgumentParser()
parser.add_argument('--models', type=int, help='Number of models to use')
parser.add_argument('--epochs', type=int, help='Number of epochs to train for')
parser.add_argument('--dataset', choices=data.spatial_dataset_map.keys(), help='Dataset to use')
parser.add_argument('--batch_size', type=int, help='Batch size')
parser.add_argument('--data', help='Root folder for the dataset')
args = parser.parse_args()

# Running sums of per-model test accuracies; averaged after the loop.
spatial_accuracies = 0
jpeg_accuracies = 0

for m in range(args.models):
    print('Train spatial model {}/{}'.format(m + 1, args.models))
    spatial_dataset = data.spatial_dataset_map[args.dataset](args.batch_size, args.data)
    dataset_info = data.dataset_info[args.dataset]

    spatial_model = models.SpatialResNet(dataset_info['channels'], dataset_info['classes']).to(device)
    optimizer = optim.Adam(spatial_model.parameters())

    for e in range(args.epochs):
        models.train(spatial_model, device, spatial_dataset[0], optimizer, e)

    # Final-epoch test accuracy of the spatial model.
    acc = models.test(spatial_model, device, spatial_dataset[1])
    spatial_accuracies += acc

    print('Convert model to JPEG')
    jpeg_model = models.JpegResNetExact(spatial_model).to(device)
    # Precompute the exploded JPEG conv operators once before evaluation.
    jpeg_model.explode_all()

    print('Test JPEG model')
    jpeg_dataset = data.jpeg_dataset_map[args.dataset](args.batch_size, args.data)
    acc = models.test(jpeg_model, device, jpeg_dataset[1])
    jpeg_accuracies += acc

spatial_accuracies /= args.models
jpeg_accuracies /= args.models

print('Report')
print('======')
print('Dataset: {}'.format(args.dataset))
print('Number of Models: {}'.format(args.models))
print('Average Spatial Accuracy: {}'.format(spatial_accuracies))
print('Average JPEG Accuracy: {}'.format(jpeg_accuracies))
print('Deviation: {}'.format(spatial_accuracies - jpeg_accuracies))
......@@ -11,8 +11,11 @@ class BatchNorm(torch.nn.modules.Module):
self.gamma = bn.weight
self.beta = bn.bias
self.gamma_final = (self.gamma / torch.sqrt(self.var)).view(1, self.gamma.shape[0], 1, 1, 1)
self.beta_final = (self.beta - (self.gamma * self.mean) / torch.sqrt(self.var)).view(1, self.beta.shape[0], 1, 1)
gamma_final = (self.gamma / torch.sqrt(self.var)).view(1, self.gamma.shape[0], 1, 1, 1)
self.register_buffer('gamma_final', gamma_final)
beta_final = (self.beta - (self.gamma * self.mean) / torch.sqrt(self.var)).view(1, self.beta.shape[0], 1, 1)
self.register_buffer('beta_final', beta_final)
def forward(self, input):
input = input * self.gamma_final
......
......@@ -3,47 +3,53 @@ import opt_einsum as oe
import numpy as np
class Conv2d(torch.nn.modules.Module):
    """JPEG-domain convolution wrapping a trained spatial nn.Conv2d.

    The diff residue interleaved the old Conv2dBase/Conv2dPre split with
    the new merged class; this is the merged version, which registers the
    JPEG transform tensors as buffers (so .to()/.cuda() moves them) and
    optionally caches the exploded operator via explode_pre().

    Parameters
    ----------
    conv_spatial : nn.Conv2d
        The spatial convolution whose weights are reused.
    J : tuple
        Pair of JPEG transform tensors; J[0] is registered as buffer 'J'
        and J[1] as buffer 'J_i'.
    """

    def __init__(self, conv_spatial, J):
        super(Conv2d, self).__init__()
        self.stride = conv_spatial.stride
        self.weight = conv_spatial.weight
        self.padding = conv_spatial.padding
        # Buffers, not parameters: moved by Module._apply, never trained.
        self.register_buffer('J', J[0])
        self.register_buffer('J_i', J[1])
        J_batched = self.J_i.contiguous().view(np.prod(self.J_i.shape[0:3]), 1, *self.J_i.shape[3:5])
        self.register_buffer('J_batched', J_batched)
        self.make_apply_op()
        # Cache for the exploded operator; filled by explode_pre().
        self.jpeg_op = None

    def make_apply_op(self):
        # Pre-optimized einsum with self.J baked in as a constant; must be
        # rebuilt whenever the buffers change device/dtype (see _apply).
        input_shape = [0, self.weight.shape[1], *self.J_i.shape[0:3]]
        jpeg_op_shape = [self.weight.shape[0], self.weight.shape[1], *self.J_i.shape[0:3], *self.J.shape[0:2]]
        self.apply_conv = oe.contract_expression('mnxyksr,srabc,tnxyk->tmabc', jpeg_op_shape, self.J, input_shape, constants=[1], optimize='optimal')
        self.apply_conv.evaluate_constants(backend='torch')

    def _apply(self, fn):
        # After .to()/.cuda() the buffers have moved, so the einsum's baked-in
        # constant is stale: rebuild it on the (possibly new) module.
        s = super(Conv2d, self)._apply(fn)
        s.make_apply_op()
        return s

    def explode(self):
        """Build the full JPEG-domain operator by convolving each filter
        with every J_i basis image."""
        out_channels = self.weight.shape[0]
        in_channels = self.weight.shape[1]
        jpeg_op = torch.nn.functional.conv2d(self.J_batched, self.weight.view(out_channels * in_channels, 1, self.weight.shape[2], self.weight.shape[3]), padding=self.padding, stride=self.stride)
        jpeg_op = jpeg_op.permute(1, 0, 2, 3)
        jpeg_op = jpeg_op.view(out_channels, in_channels, *self.J_i.shape[0:3], *(np.array(self.J_i.shape[3:5]) // self.stride))
        return jpeg_op

    def explode_pre(self):
        # Precompute and cache the exploded operator (valid once weights
        # are frozen); forward() will then skip the per-call rebuild.
        self.jpeg_op = self.explode()

    def forward(self, input):
        # Use the cached operator when available; otherwise rebuild each
        # call (needed while the spatial weights are still training).
        if self.jpeg_op is not None:
            jpeg_op = self.jpeg_op
        else:
            jpeg_op = self.explode()
        return self.apply_conv(jpeg_op, input, backend='torch')
......@@ -6,15 +6,26 @@ from jpeg_codec import D_n, D, Z, S_i, S
class ReLU(torch.nn.modules.Module):
def __init__(self, n_freqs):
super(ReLU, self).__init__()
self.C_n = torch.einsum('ijab,abg,gk->ijk', [D_n(n_freqs), Z(), S_i()])
self.Hm = torch.einsum('ijab,ijuv,abg,gk,uvh,hl->ijkl', [D(), D(), Z(), S_i(), Z(), S()])
C_n = torch.einsum('ijab,abg,gk->ijk', [D_n(n_freqs), Z(), S_i()])
self.register_buffer('C_n', C_n)
Hm = torch.einsum('ijab,ijuv,abg,gk,uvh,hl->ijkl', [D(), D(), Z(), S_i(), Z(), S()])
self.register_buffer('Hm', Hm)
self.make_masking_ops()
def make_masking_ops(self):
self.annm_op = oe.contract_expression('ijk,tmxyk->tmxyij', self.C_n, [0, 0, 0, 0, 64], constants=[0], optimize='optimal')
self.annm_op.evaluate_constants(backend='torch')
self.hsm_op = oe.contract_expression('ijkl,tmxyk,tmxyij->tmxyl', self.Hm, [0, 0, 0, 0, 64], [0, 0, 0, 0, 8, 8], constants=[0], optimize='optimal')
self.hsm_op.evaluate_constants(backend='torch')
def _apply(self, fn):
s = super(ReLU, self)._apply(fn)
s.make_masking_ops()
return s
def annm(self, x):
appx_im = self.annm_op(x, backend='torch')
mask = torch.zeros_like(appx_im)
......
......@@ -17,8 +17,12 @@ class SpatialResBlock(nn.Module):
self.bn2 = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
if downsample:
self.downsampler = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2, padding=0, bias=False)
if downsample or in_channels != out_channels:
stride = 2 if downsample else 1
self.downsampler = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, padding=0, bias=False)
self.bn_ds = nn.BatchNorm2d(out_channels)
else:
self.downsampler = None
def forward(self, x):
out = self.conv1(x)
......@@ -28,8 +32,9 @@ class SpatialResBlock(nn.Module):
out = self.conv2(out)
out = self.bn2(out)
if self.downsample:
if self.downsampler is not None:
residual = self.downsampler(x)
residual = self.bn_ds(residual)
else:
residual = x
......@@ -53,11 +58,18 @@ class JpegResBlock(nn.Module):
self.relu = jpeg_layers.ReLU(n_freqs=n_freqs)
if spatial_resblock.downsample:
self.downsample = True
if spatial_resblock.downsampler is not None:
self.downsampler = jpeg_layers.Conv2d(spatial_resblock.downsampler, J_down)
self.bn_ds = jpeg_layers.BatchNorm(spatial_resblock.bn_ds)
else:
self.downsample = False
self.downsampler = None
def explode_all(self):
self.conv1.explode_pre()
self.conv2.explode_pre()
if self.downsampler is not None:
self.downsampler.explode_pre()
def forward(self, x):
out = self.conv1(x)
......@@ -67,8 +79,9 @@ class JpegResBlock(nn.Module):
out = self.conv2(out)
out = self.bn2(out)
if self.downsample:
if self.downsampler is not None:
residual = self.downsampler(x)
residual = self.bn_ds(residual)
else:
residual = x
......
......@@ -19,6 +19,11 @@ class JpegResNet(nn.Module):
self.averagepooling = AvgPool()
self.fc = spatial_model.fc
def explode_all(self):
self.block1.explode_all()
self.block2.explode_all()
self.block3.explode_all()
def forward(self, x):
out = self.block1(x)
out = self.block2(out)
......@@ -34,4 +39,4 @@ class JpegResNet(nn.Module):
class JpegResNetExact(JpegResNet):
    """JpegResNet built with 14 frequencies (exact, no truncation).

    NOTE(review): 14 is assumed to be the n_freqs argument of
    JpegResNet.__init__ -- confirm against that signature.
    """

    def __init__(self, spatial_model):
        # super() must name this subclass: the old super(JpegResNet, self)
        # skipped JpegResNet.__init__ and handed its arguments straight to
        # nn.Module.__init__, which is wrong.
        super(JpegResNetExact, self).__init__(spatial_model, 14)
......@@ -3,15 +3,15 @@ from .blocks import SpatialResBlock
class SpatialResNet(nn.Module):
def __init__(self, n_channels, n_classes):
def __init__(self, channels, classes):
super(SpatialResNet, self).__init__()
self.block1 = SpatialResBlock(in_channels=n_channels, out_channels=16, downsample=False)
self.block1 = SpatialResBlock(in_channels=channels, out_channels=16, downsample=False)
self.block2 = SpatialResBlock(in_channels=16, out_channels=32)
self.block3 = SpatialResBlock(in_channels=32, out_channels=64)
self.averagepooling = nn.AvgPool2d(8, stride=1)
self.fc = nn.Linear(64, n_classes)
self.fc = nn.Linear(64, classes)
def forward(self, x):
out = self.block1(x)
......
......@@ -25,11 +25,13 @@ def test(model, device, test_loader):
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.cross_entropy(output, target, reduction='sum').item() # sum up batch loss
pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
test_loss += F.cross_entropy(output, target, reduction='sum').item()
pred = output.max(1, keepdim=True)[1]
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
\ No newline at end of file
100. * correct / len(test_loader.dataset)))
return correct / len(test_loader.dataset)