Add relu blockwise experiment

parent f5792788
import argparse
import jpeg_layers
import torch
import torch.nn
import numpy as np
import opt_einsum as oe
from jpeg_codec import D_n, Z, S_i, encode, decode
from scipy.misc import imresize
class AppxReLU(torch.nn.modules.Module):
def __init__(self, n_freqs):
super(AppxReLU, self).__init__()
C_n = torch.einsum('ijab,abg,gk->ijk', [D_n(n_freqs), Z(), S_i()])
self.register_buffer('C_n', C_n)
self.make_masking_ops()
def make_masking_ops(self):
self.annm_op = oe.contract_expression('ijk,tmxyk->tmxyij', self.C_n, [0, 0, 0, 0, 64], constants=[0], optimize='optimal')
self.annm_op.evaluate_constants(backend='torch')
def appx_relu(self, x):
appx_im = self.annm_op(x, backend='torch')
appx_im[appx_im < 0] = 0
return appx_im
def forward(self, input):
appx = self.appx_relu(input)
return appx
def rmse_error(a, b):
return torch.sqrt(torch.mean((a - b)**2))
parser = argparse.ArgumentParser()
parser.add_argument('--num_samples', type=int, help='Number of samples to use for each spatial frequency')
parser.add_argument('--output', help='Output CSV file name')
args = parser.parse_args()
annm_errors = np.zeros(15)
appx_errors = np.zeros(15)
device = torch.device('cuda')
spatial_relu = torch.nn.ReLU(inplace=True)
for f in range(15):
print('Processsing spatial frequency {}'.format(f))
jpeg_relu = jpeg_layers.ReLU(f)
appx_relu = AppxReLU(f)
for n in range(args.num_samples):
im = np.random.rand(4, 4) * 2 - 1
im = imresize(im, (8, 8), interp='nearest', mode='F')
im = torch.Tensor(im).unsqueeze(0).unsqueeze(0)
im_jpeg = encode(im)
true_relu = spatial_relu(im)
annm_relu = jpeg_relu(im_jpeg)
apx_relu = appx_relu(im_jpeg)
annm_im = decode(annm_relu)
annm_errors[f] += rmse_error(annm_im, true_relu)
appx_errors[f] += rmse_error(apx_relu, true_relu)
annm_errors /= args.num_samples
appx_errors /= args.num_samples
with open(args.output, 'w') as f:
f.write('ANNM,APPX\n')
for i in range(15):
f.write('{},{}\n'.format(annm_errors[i], appx_errors[i]))
...@@ -2,3 +2,4 @@ opt_einsum ...@@ -2,3 +2,4 @@ opt_einsum
torch torch
torchvision torchvision
numpy numpy
scipy
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment