...
 
Commits (9)
......@@ -12,6 +12,7 @@ paper/*.log
paper/*.pdf
paper/*.run.xml
paper/*.synctex.gz
paper/**/*.eps
# User-specific stuff
......
# Experiment script: train several spatial ResNets, convert each one to the
# ASM and APX JPEG-domain variants for every spatial-frequency setting, and
# write the accuracies averaged over all trained models to a CSV file.
import argparse

import torch
import torch.optim as optim

import data
import models

# Number of spatial-frequency settings swept (n_freqs = 0 .. N_FREQS - 1).
N_FREQS = 15

# Prefer the GPU, but keep the script runnable on CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

parser = argparse.ArgumentParser()
# These three are required: without them the script used to fail later with an
# opaque TypeError (range(None)) or KeyError (dataset map lookup).
parser.add_argument('--models', type=int, required=True, help='Number of models to use')
parser.add_argument('--epochs', type=int, required=True, help='Number of epochs to train for')
parser.add_argument('--dataset', choices=data.spatial_dataset_map.keys(), required=True, help='Dataset to use')
parser.add_argument('--batch_size', type=int, help='Batch size')
parser.add_argument('--data', help='Root folder for the dataset')
args = parser.parse_args()

# Running sums over models; divided by the model count after the loop.
spatial_accuracies = 0
asm_accuracies = torch.zeros(N_FREQS)
apx_accuracies = torch.zeros(N_FREQS)

dataset_info = data.dataset_info[args.dataset]

for m in range(args.models):
    print('Train spatial model {}/{}'.format(m + 1, args.models))
    spatial_dataset = data.spatial_dataset_map[args.dataset](args.batch_size, args.data)
    spatial_model = models.SpatialResNet(dataset_info['channels'], dataset_info['classes']).to(device)
    optimizer = optim.Adam(spatial_model.parameters())

    for e in range(args.epochs):
        models.train(spatial_model, device, spatial_dataset[0], optimizer, e)
        # Per-epoch evaluation; its return value is not recorded.
        models.test(spatial_model, device, spatial_dataset[1])

    # Final evaluation of the trained spatial model is the one that is recorded.
    acc = models.test(spatial_model, device, spatial_dataset[1])
    spatial_accuracies += acc

    # The JPEG dataset depends only on the command-line arguments, so build it
    # once per model instead of once per frequency setting.
    jpeg_dataset = data.jpeg_dataset_map[args.dataset](args.batch_size, args.data)

    for n_freqs in range(N_FREQS):
        print('Convert model to ASM JPEG with {} spatial frequencies'.format(n_freqs))
        jpeg_model = models.JpegResNet(spatial_model, n_freqs=n_freqs).to(device)
        jpeg_model.explode_all()

        print('Test ASM JPEG model')
        acc = models.test(jpeg_model, device, jpeg_dataset[1])
        asm_accuracies[n_freqs] += acc

        print('Convert model to APX JPEG with {} spatial frequencies'.format(n_freqs))
        jpeg_model = models.JpegResNetApx(spatial_model, n_freqs=n_freqs).to(device)
        jpeg_model.explode_all()

        print('Test APX JPEG model')
        acc = models.test(jpeg_model, device, jpeg_dataset[1])
        apx_accuracies[n_freqs] += acc

# Turn the sums into per-model averages.
spatial_accuracies /= args.models
asm_accuracies /= args.models
apx_accuracies /= args.models

with open('{}_relu_accuracy.csv'.format(args.dataset), 'w') as out:
    out.write('Spatial, ASM, APX\n')
    # The spatial accuracy does not depend on the frequency setting, so the
    # same averaged value is repeated on every row.
    for i in range(N_FREQS):
        out.write('{}, {}, {}\n'.format(spatial_accuracies, asm_accuracies[i].item(), apx_accuracies[i].item()))
......@@ -9,28 +9,6 @@ from jpeg_codec import D_n, Z, S_i, encode, decode
device = torch.device('cpu')
class AppxReLU(torch.nn.modules.Module):
    """ReLU approximation that zeroes the negative entries of a contracted input.

    The constructor combines ``D_n(n_freqs)``, ``Z()`` and ``S_i()`` from
    ``jpeg_codec`` into a single ``(i, j, k)`` tensor stored as the ``C_n``
    buffer on the module-level ``device``. The forward pass contracts the
    input with that buffer and sets every negative entry of the result to 0.

    NOTE(review): the last input axis is declared with size 64 in the
    contraction; presumably this is one 8x8 block of JPEG coefficients.
    Confirm against ``jpeg_codec``.
    """

    def __init__(self, n_freqs):
        super(AppxReLU, self).__init__()

        operands = [D_n(n_freqs), Z(), S_i()]
        self.register_buffer('C_n', torch.einsum('ijab,abg,gk->ijk', operands).to(device))

        self.make_masking_ops()

    def make_masking_ops(self):
        """Build the contraction of ``C_n`` with an input whose last axis has size 64."""
        expression = oe.contract_expression(
            'ijk,tmxyk->tmxyij',
            self.C_n,
            [0, 0, 0, 0, 64],
            constants=[0],
            optimize='optimal',
        )
        expression.evaluate_constants(backend='torch')
        self.annm_op = expression

    def appx_relu(self, x):
        """Contract ``x`` with ``C_n`` and zero the negative entries of the result."""
        contracted = self.annm_op(x, backend='torch')
        negative = contracted < 0
        contracted[negative] = 0
        return contracted

    def forward(self, input):
        """Return ``appx_relu(input)``."""
        return self.appx_relu(input)
def rmse_error(a, b):
    """Return the root-mean-square error between tensors ``a`` and ``b``.

    The difference is squared element-wise, averaged over all elements, and
    the square root of that mean is returned as a 0-dim tensor.
    """
    residual = a - b
    return residual.pow(2).mean().sqrt()
......@@ -58,12 +36,12 @@ args = parser.parse_args()
annm_errors = np.zeros(15)
appx_errors = np.zeros(15)
spatial_relu = torch.nn.ReLU(inplace=True).to(device)
spatial_relu = torch.nn.ReLU(inplace=False).to(device)
for f in range(15):
print('Processsing spatial frequency {}'.format(f))
jpeg_relu = jpeg_layers.ReLU(f).to(device)
appx_relu = AppxReLU(f).to(device)
jpeg_relu = jpeg_layers.ASMReLU(f).to(device)
appx_relu = jpeg_layers.APXReLU(f).to(device)
for b in range(args.batches):
im = torch.rand(args.batch_size, 1, 4, 4).to(device) * 2 - 1
......@@ -76,16 +54,16 @@ for f in range(15):
apx_relu = appx_relu(im_jpeg)
annm_im = decode(annm_relu, device=device)
apx_im = apx_relu.view(-1, 1, 8, 8)
apx_im = decode(apx_relu, device=device)
annm_errors[f] += rmse_error(annm_im, true_relu)
appx_errors[f] += rmse_error(apx_im, true_relu)
annm_errors /= args.batches * args.batch_size
appx_errors /= args.batches * args.batch_size
annm_errors /= args.batches
appx_errors /= args.batches
with open(args.output, 'w') as f:
f.write('ANNM,APPX\n')
f.write('ASM,APX\n')
for i in range(15):
f.write('{},{}\n'.format(annm_errors[i], appx_errors[i]))
import torch
import opt_einsum as oe
from jpeg_codec import D_n, D, Z, S_i, S
from jpeg_codec import encode
class ReLU(torch.nn.modules.Module):
class ASMReLU(torch.nn.modules.Module):
def __init__(self, n_freqs):
super(ReLU, self).__init__()
super(ASMReLU, self).__init__()
C_n = torch.einsum('ijab,abg,gk->ijk', [D_n(n_freqs), Z(), S_i()])
self.register_buffer('C_n', C_n)
......@@ -22,7 +23,7 @@ class ReLU(torch.nn.modules.Module):
self.hsm_op.evaluate_constants(backend='torch')
def _apply(self, fn):
s = super(ReLU, self)._apply(fn)
s = super(ASMReLU, self)._apply(fn)
s.make_masking_ops()
return s
......@@ -39,3 +40,37 @@ class ReLU(torch.nn.modules.Module):
annm = self.annm(input)
out_comp = self.half_spatial_mask(input, annm)
return out_comp
class APXReLU(torch.nn.modules.Module):
    """Approximate ReLU: contract with ``C_n``, zero negatives, contract with ``C``.

    Two buffers are registered, both computed with the same einsum over
    ``D_n(n_freqs)`` and ``Z()``: ``C_n`` uses ``S_i()`` as the third operand
    and ``C`` uses ``S()``. The forward pass contracts the input with ``C_n``,
    sets every negative entry of the result to 0, and contracts that result
    with ``C``.

    NOTE(review): the contractions declare a trailing input axis of size 64
    and an intermediate of size 8x8; presumably these are JPEG coefficient
    blocks and their pixel blocks. Confirm against ``jpeg_codec``.
    """

    def __init__(self, n_freqs):
        super(APXReLU, self).__init__()

        to_intermediate = torch.einsum('ijab,abg,gk->ijk', [D_n(n_freqs), Z(), S_i()])
        self.register_buffer('C_n', to_intermediate)

        from_intermediate = torch.einsum('ijab,abg,gk->ijk', [D_n(n_freqs), Z(), S()])
        self.register_buffer('C', from_intermediate)

        self.make_masking_ops()

    def make_masking_ops(self):
        """Build both contraction expressions from the current buffers."""
        expand = oe.contract_expression(
            'ijk,tmxyk->tmxyij',
            self.C_n,
            [0, 0, 0, 0, 64],
            constants=[0],
            optimize='optimal',
        )
        expand.evaluate_constants(backend='torch')
        self.annm_op = expand

        compress = oe.contract_expression(
            'ijk,tmxyij->tmxyk',
            self.C,
            [0, 0, 0, 0, 8, 8],
            constants=[0],
            optimize='optimal',
        )
        compress.evaluate_constants(backend='torch')
        self.compress_op = compress

    def _apply(self, fn):
        # The expressions hold the buffers as constants, so they are rebuilt
        # every time ``fn`` has been applied to the module.
        applied = super(APXReLU, self)._apply(fn)
        applied.make_masking_ops()
        return applied

    def appx_relu(self, x):
        """Contract ``x`` with ``C_n`` and zero the negative entries of the result."""
        contracted = self.annm_op(x, backend='torch')
        negative = contracted < 0
        contracted[negative] = 0
        return contracted

    def forward(self, input):
        """Apply ``appx_relu`` and contract its output with ``C``."""
        thresholded = self.appx_relu(input)
        return self.compress_op(thresholded, backend='torch')
......@@ -45,7 +45,7 @@ class SpatialResBlock(nn.Module):
class JpegResBlock(nn.Module):
def __init__(self, spatial_resblock, n_freqs, J_in, J_out):
def __init__(self, spatial_resblock, n_freqs, J_in, J_out, relu_layer=jpeg_layers.ASMReLU):
super(JpegResBlock, self).__init__()
J_down = (J_out[0], J_in[1])
......@@ -56,7 +56,7 @@ class JpegResBlock(nn.Module):
self.bn1 = jpeg_layers.BatchNorm(spatial_resblock.bn1)
self.bn2 = jpeg_layers.BatchNorm(spatial_resblock.bn2)
self.relu = jpeg_layers.ReLU(n_freqs=n_freqs)
self.relu = relu_layer(n_freqs=n_freqs)
if spatial_resblock.downsampler is not None:
self.downsampler = jpeg_layers.Conv2d(spatial_resblock.downsampler, J_down)
......
import torch.nn as nn
from .blocks import JpegResBlock
from jpeg_layers import AvgPool
from jpeg_layers import AvgPool, ASMReLU, APXReLU
from jpeg_codec import codec
class JpegResNet(nn.Module):
def __init__(self, spatial_model, n_freqs):
def __init__(self, spatial_model, n_freqs, relu_layer=ASMReLU):
super(JpegResNet, self).__init__()
J_32 = codec((32, 32))
J_16 = codec((16, 16))
J_8 = codec((8, 8))
self.block1 = JpegResBlock(spatial_model.block1, n_freqs=n_freqs, J_in=J_32, J_out=J_32)
self.block2 = JpegResBlock(spatial_model.block2, n_freqs=n_freqs, J_in=J_32, J_out=J_16)
self.block3 = JpegResBlock(spatial_model.block3, n_freqs=n_freqs, J_in=J_16, J_out=J_8)
self.block1 = JpegResBlock(spatial_model.block1, n_freqs=n_freqs, J_in=J_32, J_out=J_32, relu_layer=relu_layer)
self.block2 = JpegResBlock(spatial_model.block2, n_freqs=n_freqs, J_in=J_32, J_out=J_16, relu_layer=relu_layer)
self.block3 = JpegResBlock(spatial_model.block3, n_freqs=n_freqs, J_in=J_16, J_out=J_8, relu_layer=relu_layer)
self.averagepooling = AvgPool()
self.fc = spatial_model.fc
......@@ -40,3 +40,8 @@ class JpegResNet(nn.Module):
class JpegResNetExact(JpegResNet):
    """``JpegResNet`` with ``n_freqs`` fixed at 14 and the parent's default ReLU layer."""

    def __init__(self, spatial_model):
        super(JpegResNetExact, self).__init__(spatial_model, n_freqs=14)
class JpegResNetApx(JpegResNet):
    """``JpegResNet`` whose ReLU layers are ``APXReLU`` instead of the default."""

    def __init__(self, spatial_model, n_freqs):
        super(JpegResNetApx, self).__init__(
            spatial_model,
            n_freqs=n_freqs,
            relu_layer=APXReLU,
        )
ASM,APX
0.351330280720293,0.370587906224728
0.304883404178619,0.339589755077958
0.250888954913914,0.303201176722944
0.183227186731398,0.259644525761902
0.133206807076111,0.224467879328728
0.087671215625405,0.185548062829822
0.056020900195502,0.147140650034249
0.038041077943556,0.114263159455061
0.019215256675147,0.075466101892516
0.011677978458144,0.05492251554165
0.006982008791193,0.038950518101715
0.003444718584642,0.024361371224523
0.001145424387555,0.011425326326424
0.00024486728348,0.003880036611981
1.33567197207896E-07,9.73571013530972E-08
#!/usr/bin/gnuplot -c
# Plot the first two columns of a comma-separated file against the row number.
#
# Arguments:
#   ARG1  data file to plot; the curve titles are taken from its column headers
#   ARG2  terminal name; 'eps' selects the postscript settings below, any other
#         value is passed to `set terminal` and the script pauses after plotting
if (ARG2 eq 'eps') {
set terminal postscript enhanced color font "Helvetica, 30" eps
set size 3,2
} else {
set terminal ARG2
}
set datafile separator ','
set ylabel 'Average RMSE'
set grid ytics
# Legend goes above the plot, entries laid out side by side.
set key tmargin horizontal
set xlabel 'Number of Spatial Frequencies'
set xrange ['1':'15']
set xtics 1
# NOTE(review): the output file name is fixed regardless of ARG2 - confirm this
# is intended when a non-eps terminal is requested.
set output "relu_blocks.eps"
# x is the record number plus one; column 2 is drawn first, then column 1.
plot ARG1 using ($0+1):2 with linespoints linewidth 8 pointsize 5 title columnhead, \
ARG1 using ($0+1):1 with linespoints linewidth 8 pointsize 5 title columnhead
# Keep an interactive window open until the user confirms.
if (ARG2 ne 'eps') pause -1 "Press any key to continue..."
\ No newline at end of file