Fix shape broadcasting bug

parent 90171997
......@@ -5,9 +5,8 @@ import torch.nn
import numpy as np
import opt_einsum as oe
from jpeg_codec import D_n, Z, S_i, encode, decode
from scipy.misc import imresize
device = torch.device('cuda')
device = torch.device('cpu')
class AppxReLU(torch.nn.modules.Module):
......@@ -77,9 +76,10 @@ for f in range(15):
apx_relu = appx_relu(im_jpeg)
annm_im = decode(annm_relu, device=device)
apx_im = apx_relu.view(-1, 1, 8, 8)
annm_errors[f] += rmse_error(annm_im, true_relu)
appx_errors[f] += rmse_error(apx_relu, true_relu)
appx_errors[f] += rmse_error(apx_im, true_relu)
annm_errors /= args.batches * args.batch_size
appx_errors /= args.batches * args.batch_size
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment