...
 
Commits (3)
import argparse
import jpeg_layers
import torch
import torch.nn
import numpy as np
import opt_einsum as oe
from jpeg_codec import D_n, Z, S_i, encode, decode
from scipy.misc import imresize
device = torch.device('cuda')
class AppxReLU(torch.nn.modules.Module):
def __init__(self, n_freqs):
super(AppxReLU, self).__init__()
C_n = torch.einsum('ijab,abg,gk->ijk', [D_n(n_freqs), Z(), S_i()]).to(device)
self.register_buffer('C_n', C_n)
self.make_masking_ops()
def make_masking_ops(self):
self.annm_op = oe.contract_expression('ijk,tmxyk->tmxyij', self.C_n, [0, 0, 0, 0, 64], constants=[0], optimize='optimal')
self.annm_op.evaluate_constants(backend='torch')
def appx_relu(self, x):
appx_im = self.annm_op(x, backend='torch')
appx_im[appx_im < 0] = 0
return appx_im
def forward(self, input):
appx = self.appx_relu(input)
return appx
def rmse_error(a, b):
return torch.sqrt(torch.mean((a - b)**2))
def double_size_tensor(shape):
op = torch.zeros((shape[0], shape[1], shape[0] * 2, shape[1] * 2)).to(device)
for i in range(0, shape[0]):
for j in range(0, shape[1]):
for u in range(0, shape[0] * 2):
for v in range(0, shape[1] * 2):
if i == u // 2 and j == v // 2:
op[i, j, u, v] = 1
return op
doubling_tensor = double_size_tensor((4, 4))
parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, help='Number of samples per batch')
parser.add_argument('--batches', type=int, help='Number of batches')
parser.add_argument('--output', help='Output CSV file name')
args = parser.parse_args()
annm_errors = np.zeros(15)
appx_errors = np.zeros(15)
spatial_relu = torch.nn.ReLU(inplace=True).to(device)
for f in range(15):
print('Processsing spatial frequency {}'.format(f))
jpeg_relu = jpeg_layers.ReLU(f).to(device)
appx_relu = AppxReLU(f).to(device)
for b in range(args.batches):
im = torch.rand(args.batch_size, 1, 4, 4).to(device) * 2 - 1
im = torch.einsum('ijuv,ncij->ncuv', [doubling_tensor, im])
im_jpeg = encode(im, device=device)
true_relu = spatial_relu(im)
annm_relu = jpeg_relu(im_jpeg)
apx_relu = appx_relu(im_jpeg)
annm_im = decode(annm_relu, device=device)
annm_errors[f] += rmse_error(annm_im, true_relu)
appx_errors[f] += rmse_error(apx_relu, true_relu)
annm_errors /= args.batches * args.batch_size
appx_errors /= args.batches * args.batch_size
with open(args.output, 'w') as f:
f.write('ANNM,APPX\n')
for i in range(15):
f.write('{},{}\n'.format(annm_errors[i], appx_errors[i]))
......@@ -25,8 +25,8 @@ def encode(batch, block_size=(8, 8), device=None):
def decode(batch, device=None):
block_size = np.sqrt(batch.shape[4])
image_size = (batch.shape[2] * block_size, batch.shape[3] * block_size)
block_size = int(np.sqrt(batch.shape[4]))
image_size = (int(batch.shape[2] * block_size), int(batch.shape[3] * block_size))
_, J_i = codec(image_size, (block_size, block_size))
......
......@@ -2,3 +2,4 @@ opt_einsum
torch
torchvision
numpy
scipy
\ No newline at end of file