Use apxrelu from the layers package

parent e249d6e7
......@@ -9,28 +9,6 @@ from jpeg_codec import D_n, Z, S_i, encode, decode
device = torch.device('cpu')
class AppxReLU(torch.nn.modules.Module):
def __init__(self, n_freqs):
super(AppxReLU, self).__init__()
C_n = torch.einsum('ijab,abg,gk->ijk', [D_n(n_freqs), Z(), S_i()]).to(device)
self.register_buffer('C_n', C_n)
self.make_masking_ops()
def make_masking_ops(self):
self.annm_op = oe.contract_expression('ijk,tmxyk->tmxyij', self.C_n, [0, 0, 0, 0, 64], constants=[0], optimize='optimal')
self.annm_op.evaluate_constants(backend='torch')
def appx_relu(self, x):
appx_im = self.annm_op(x, backend='torch')
appx_im[appx_im < 0] = 0
return appx_im
def forward(self, input):
appx = self.appx_relu(input)
return appx
def rmse_error(a, b):
return torch.sqrt(torch.mean((a - b)**2))
......@@ -62,8 +40,8 @@ spatial_relu = torch.nn.ReLU(inplace=False).to(device)
for f in range(15):
print('Processsing spatial frequency {}'.format(f))
jpeg_relu = jpeg_layers.ReLU(f).to(device)
appx_relu = AppxReLU(f).to(device)
jpeg_relu = jpeg_layers.ASMReLU(f).to(device)
appx_relu = jpeg_layers.APXReLU(f).to(device)
for b in range(args.batches):
im = torch.rand(args.batch_size, 1, 4, 4).to(device) * 2 - 1
......@@ -76,7 +54,7 @@ for f in range(15):
apx_relu = appx_relu(im_jpeg)
annm_im = decode(annm_relu, device=device)
apx_im = apx_relu.view(-1, 1, 8, 8)
apx_im = decode(apx_relu, device=device)
annm_errors[f] += rmse_error(annm_im, true_relu)
appx_errors[f] += rmse_error(apx_im, true_relu)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment