Faster processing using batched operations

parent 66175009
......@@ -7,11 +7,13 @@ import opt_einsum as oe
from jpeg_codec import D_n, Z, S_i, encode, decode
from scipy.misc import imresize
device = torch.device('cuda')
class AppxReLU(torch.nn.modules.Module):
def __init__(self, n_freqs):
super(AppxReLU, self).__init__()
C_n = torch.einsum('ijab,abg,gk->ijk', [D_n(n_freqs), Z(), S_i()])
C_n = torch.einsum('ijab,abg,gk->ijk', [D_n(n_freqs), Z(), S_i()]).to(device)
self.register_buffer('C_n', C_n)
self.make_masking_ops()
......@@ -34,41 +36,53 @@ def rmse_error(a, b):
return torch.sqrt(torch.mean((a - b)**2))
def double_size_tensor(shape):
    """Build the linear operator that doubles a 2-D block by repeating entries.

    The result ``op`` has shape
    ``(shape[0], shape[1], shape[0] * 2, shape[1] * 2)`` and satisfies
    ``op[i, j, u, v] == 1`` exactly when ``i == u // 2`` and ``j == v // 2``;
    every other entry is 0.  Contracting it with an input over ``i, j``
    (e.g. ``einsum('ijuv,ncij->ncuv', ...)``) therefore copies each input
    entry into a 2x2 patch of the output.

    Args:
        shape: Sequence whose first two entries are the input height and
            width.

    Returns:
        A float tensor placed on the module-level ``device``.
    """
    rows, cols = shape[0], shape[1]

    # row_map[i, u] == 1 iff i == u // 2: every column of the identity is
    # duplicated side by side.  col_map is the same thing for the width axis.
    row_map = torch.eye(rows).unsqueeze(2).expand(rows, rows, 2).reshape(rows, rows * 2)
    col_map = torch.eye(cols).unsqueeze(2).expand(cols, cols, 2).reshape(cols, cols * 2)

    # The two conditions are independent, so the 4-D operator is the outer
    # product of the two maps.  Building it in one shot replaces the original
    # four nested Python loops and their per-element writes.
    op = row_map[:, None, :, None] * col_map[None, :, None, :]
    return op.to(device)
# NOTE(review): this section appears to come from a rendered diff with the
# +/- markers and indentation stripped, so removed and added lines sit next
# to each other -- confirm against the real file before relying on it.

# Upsampling operator for 4x4 inputs, built once and reused for every batch.
doubling_tensor = double_size_tensor((4, 4))
# Command-line options.  --num_samples is read by the per-sample loop and the
# first normalization below; --batch_size / --batches are read by the batched
# loop and the second normalization.
parser = argparse.ArgumentParser()
parser.add_argument('--num_samples', type=int, help='Number of samples to use for each spatial frequency')
parser.add_argument('--batch_size', type=int, help='Number of samples per batch')
parser.add_argument('--batches', type=int, help='Number of batches')
parser.add_argument('--output', help='Output CSV file name')
args = parser.parse_args()
# One accumulated error value per spatial frequency index (15 of them) for
# each of the two ReLU variants being compared.
annm_errors = np.zeros(15)
appx_errors = np.zeros(15)
device = torch.device('cuda')
# Reference ReLU applied directly to the image.  inplace=True means the
# tensor passed in is overwritten.  The second assignment replaces the first;
# presumably these are the old and new versions of the same line.
spatial_relu = torch.nn.ReLU(inplace=True)
spatial_relu = torch.nn.ReLU(inplace=True).to(device)
# Main experiment: for each of the 15 frequency settings, compare the two
# JPEG-domain ReLU variants against the reference ReLU and accumulate RMSE.
# NOTE(review): indentation and diff markers are missing here, so loop nesting
# cannot be read off the text; the per-sample lines and the batched lines look
# like the old and new versions of the same loop -- confirm.
for f in range(15):
# NOTE(review): 'Processsing' is misspelled in the message below.
print('Processsing spatial frequency {}'.format(f))
jpeg_relu = jpeg_layers.ReLU(f)
appx_relu = AppxReLU(f)
# Per-sample path: one random 4x4 image with values in [-1, 1), resized to
# 8x8 with nearest-neighbour interpolation, then given two leading axes.
for n in range(args.num_samples):
im = np.random.rand(4, 4) * 2 - 1
im = imresize(im, (8, 8), interp='nearest', mode='F')
im = torch.Tensor(im).unsqueeze(0).unsqueeze(0)
# Same two layers as above, additionally moved to `device`.
jpeg_relu = jpeg_layers.ReLU(f).to(device)
appx_relu = AppxReLU(f).to(device)
im_jpeg = encode(im)
# Batched path: batch_size random 4x4 images in [-1, 1) created at once and
# upsampled by contracting with doubling_tensor over the i, j axes.
for b in range(args.batches):
im = torch.rand(args.batch_size, 1, 4, 4).to(device) * 2 - 1
im = torch.einsum('ijuv,ncij->ncuv', [doubling_tensor, im])
im_jpeg = encode(im, device=device)
# spatial_relu is in-place, so `im` itself is overwritten here; the encoded
# copy im_jpeg was taken before this call.
true_relu = spatial_relu(im)
annm_relu = jpeg_relu(im_jpeg)
apx_relu = appx_relu(im_jpeg)
# Only the jpeg_layers.ReLU output is passed through decode(); the AppxReLU
# output is compared to true_relu as-is.
annm_im = decode(annm_relu)
annm_im = decode(annm_relu, device=device)
annm_errors[f] += rmse_error(annm_im, true_relu)
appx_errors[f] += rmse_error(apx_relu, true_relu)
# Turn the accumulated sums into averages.
annm_errors /= args.num_samples
appx_errors /= args.num_samples
# NOTE(review): rmse_error reduces with torch.mean over all elements, so each
# call contributes a single value for a whole batch.  If the accumulation runs
# once per batch, dividing by batches * batch_size (rather than batches) would
# shrink the result by a factor of batch_size -- confirm the intended scaling.
annm_errors /= args.batches * args.batch_size
appx_errors /= args.batches * args.batch_size
with open(args.output, 'w') as f:
f.write('ANNM,APPX\n')
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment