Verified Commit 5726be9f authored by Max Ehrlich

Modify model code to be device independent

parent 84dd7d06
@@ -11,8 +11,11 @@ class BatchNorm(torch.nn.modules.Module):
         self.gamma = bn.weight
         self.beta = bn.bias
 
-        self.gamma_final = (self.gamma / torch.sqrt(self.var)).view(1, self.gamma.shape[0], 1, 1, 1)
-        self.beta_final = (self.beta - (self.gamma * self.mean) / torch.sqrt(self.var)).view(1, self.beta.shape[0], 1, 1)
+        gamma_final = (self.gamma / torch.sqrt(self.var)).view(1, self.gamma.shape[0], 1, 1, 1)
+        self.register_buffer('gamma_final', gamma_final)
+        beta_final = (self.beta - (self.gamma * self.mean) / torch.sqrt(self.var)).view(1, self.beta.shape[0], 1, 1)
+        self.register_buffer('beta_final', beta_final)
 
     def forward(self, input):
         input = input * self.gamma_final
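The change above is the core of the commit: derived constants that used to be plain tensor attributes are now registered as buffers, so `model.to(device)` and `model.cuda()` move them together with the parameters. A minimal sketch of the idea, using an illustrative `FoldedBatchNorm` class that is not part of this repository:

```python
import torch


class FoldedBatchNorm(torch.nn.Module):
    """Folds frozen batch-norm statistics into per-channel scale/shift buffers."""

    def __init__(self, bn):
        super(FoldedBatchNorm, self).__init__()
        with torch.no_grad():
            scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
            shift = bn.bias - bn.running_mean * scale
        # Buffers are not trained, but they live in state_dict and are moved
        # by .to() / .cuda(), unlike plain tensor attributes.
        self.register_buffer('gamma_final', scale.view(1, -1, 1, 1))
        self.register_buffer('beta_final', shift.view(1, -1, 1, 1))

    def forward(self, x):
        return x * self.gamma_final + self.beta_final


if torch.cuda.is_available():
    m = FoldedBatchNorm(torch.nn.BatchNorm2d(16)).cuda()
    print(m.gamma_final.device)  # cuda:0 -- a plain attribute would have stayed on cpu
```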
@@ -11,29 +11,40 @@ class Conv2dBase(torch.nn.modules.Module):
         self.weight = conv_spatial.weight
         self.padding = conv_spatial.padding
 
-        self.J = J
-        self.J_batched = self.J[1].contiguous().view(np.prod(self.J[1].shape[0:3]), 1, *self.J[1].shape[3:5])
+        self.register_buffer('J', J[0])
+        self.register_buffer('J_i', J[1])
+        J_batched = self.J_i.contiguous().view(np.prod(self.J_i.shape[0:3]), 1, *self.J_i.shape[3:5])
+        self.register_buffer('J_batched', J_batched)
 
-        input_shape = [0, self.weight.shape[1], *self.J[1].shape[0:3]]
-        jpeg_op_shape = [self.weight.shape[0], self.weight.shape[1], *self.J[1].shape[0:3], *self.J[0].shape[0:2]]
+        self.make_apply_op()
 
-        self.apply_conv = oe.contract_expression('mnxyksr,srabc,tnxyk->tmabc', jpeg_op_shape, self.J[0], input_shape, constants=[1], optimize='optimal')
+    def make_apply_op(self):
+        input_shape = [0, self.weight.shape[1], *self.J_i.shape[0:3]]
+        jpeg_op_shape = [self.weight.shape[0], self.weight.shape[1], *self.J_i.shape[0:3], *self.J.shape[0:2]]
+        self.apply_conv = oe.contract_expression('mnxyksr,srabc,tnxyk->tmabc', jpeg_op_shape, self.J, input_shape, constants=[1], optimize='optimal')
         self.apply_conv.evaluate_constants(backend='torch')
 
+    def _apply(self, fn):
+        s = super(Conv2dBase, self)._apply(fn)
+        s.make_apply_op()
+        return s
+
     def explode(self):
         out_channels = self.weight.shape[0]
         in_channels = self.weight.shape[1]
 
         jpeg_op = torch.nn.functional.conv2d(self.J_batched, self.weight.view(out_channels * in_channels, 1, self.weight.shape[2], self.weight.shape[3]), padding=self.padding, stride=self.stride)
         jpeg_op = jpeg_op.permute(1, 0, 2, 3)
-        jpeg_op = jpeg_op.view(out_channels, in_channels, *self.J[1].shape[0:3], *(np.array(self.J[1].shape[3:5]) // self.stride))
+        jpeg_op = jpeg_op.view(out_channels, in_channels, *self.J_i.shape[0:3], *(np.array(self.J_i.shape[3:5]) // self.stride))
 
         return jpeg_op
 
-class Conv2d(Conv2dBase):
+class Conv2dRT(Conv2dBase):
     def __init__(self, conv_spatial, J):
-        super(Conv2d, self).__init__(conv_spatial, J)
+        super(Conv2dRT, self).__init__(conv_spatial, J)
 
     def forward(self, input):
         jpeg_op = self.explode()
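Buffers alone do not cover the pre-built opt_einsum expressions: `oe.contract_expression(..., constants=[...])` followed by `evaluate_constants(backend='torch')` caches the constant operand on whatever device it occupied at build time. Overriding `_apply`, the hook that `.to()`, `.cuda()` and `.float()` all funnel through, lets the layer rebuild the cached expression after its buffers have been moved, which is what `make_apply_op` is for. A rough, self-contained illustration of the same pattern (the `EinsumLayer` name and the tiny contraction are made up for this example):

```python
import torch
import opt_einsum as oe


class EinsumLayer(torch.nn.Module):
    """Applies a fixed linear operator J via a pre-optimized einsum expression."""

    def __init__(self, J):
        super(EinsumLayer, self).__init__()
        self.register_buffer('J', J)
        self.make_apply_op()

    def make_apply_op(self):
        # Operand 0 is declared constant and evaluated once; the cached result
        # lives on whatever device self.J is on *right now*, so the expression
        # must be rebuilt after the module is moved.
        self.apply_op = oe.contract_expression(
            'ij,bj->bi', self.J, [0, self.J.shape[1]],
            constants=[0], optimize='optimal')
        self.apply_op.evaluate_constants(backend='torch')

    def _apply(self, fn):
        # .to(), .cuda(), .float(), ... all route through _apply. After the
        # buffers have been converted, rebuild the cached expression.
        s = super(EinsumLayer, self)._apply(fn)
        s.make_apply_op()
        return s

    def forward(self, x):
        return self.apply_op(x, backend='torch')


layer = EinsumLayer(torch.randn(4, 8))
print(layer(torch.randn(2, 8)).shape)  # torch.Size([2, 4])
```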
@@ -6,15 +6,26 @@ from jpeg_codec import D_n, D, Z, S_i, S
 class ReLU(torch.nn.modules.Module):
     def __init__(self, n_freqs):
         super(ReLU, self).__init__()
 
-        self.C_n = torch.einsum('ijab,abg,gk->ijk', [D_n(n_freqs), Z(), S_i()])
-        self.Hm = torch.einsum('ijab,ijuv,abg,gk,uvh,hl->ijkl', [D(), D(), Z(), S_i(), Z(), S()])
+        C_n = torch.einsum('ijab,abg,gk->ijk', [D_n(n_freqs), Z(), S_i()])
+        self.register_buffer('C_n', C_n)
+        Hm = torch.einsum('ijab,ijuv,abg,gk,uvh,hl->ijkl', [D(), D(), Z(), S_i(), Z(), S()])
+        self.register_buffer('Hm', Hm)
 
+        self.make_masking_ops()
+
+    def make_masking_ops(self):
         self.annm_op = oe.contract_expression('ijk,tmxyk->tmxyij', self.C_n, [0, 0, 0, 0, 64], constants=[0], optimize='optimal')
         self.annm_op.evaluate_constants(backend='torch')
 
         self.hsm_op = oe.contract_expression('ijkl,tmxyk,tmxyij->tmxyl', self.Hm, [0, 0, 0, 0, 64], [0, 0, 0, 0, 8, 8], constants=[0], optimize='optimal')
         self.hsm_op.evaluate_constants(backend='torch')
 
+    def _apply(self, fn):
+        s = super(ReLU, self)._apply(fn)
+        s.make_masking_ops()
+        return s
+
     def annm(self, x):
         appx_im = self.annm_op(x, backend='torch')
         mask = torch.zeros_like(appx_im)
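The ReLU approximation follows the same recipe: the einsum-built constants `C_n` and `Hm` become buffers, and `_apply` rebuilds the cached masking ops after any device or dtype change. One quick way to sanity-check that a converted model really is device independent is to walk its parameters and buffers after moving it; a small helper sketch (the commented usage is hypothetical):

```python
import torch


def assert_on_device(model, device):
    """Verify every parameter and registered buffer of `model` lives on `device`."""
    device = torch.device(device)
    for name, t in list(model.named_parameters()) + list(model.named_buffers()):
        assert t.device.type == device.type, \
            '{} is on {}, expected {}'.format(name, t.device, device)


# Hypothetical usage once a JPEG-domain model has been built and moved:
# model = JpegResNetExact(spatial_model).to('cuda')
# assert_on_device(model, 'cuda')
```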
@@ -45,8 +45,8 @@ class JpegResBlock(nn.Module):
         J_down = (J_out[0], J_in[1])
 
-        self.conv1 = jpeg_layers.Conv2d(spatial_resblock.conv1, J_down)
-        self.conv2 = jpeg_layers.Conv2d(spatial_resblock.conv2, J_out)
+        self.conv1 = jpeg_layers.Conv2dRT(spatial_resblock.conv1, J_down)
+        self.conv2 = jpeg_layers.Conv2dRT(spatial_resblock.conv2, J_out)
 
         self.bn1 = jpeg_layers.BatchNorm(spatial_resblock.bn1)
         self.bn2 = jpeg_layers.BatchNorm(spatial_resblock.bn2)
@@ -55,7 +55,7 @@ class JpegResBlock(nn.Module):
         if spatial_resblock.downsample:
             self.downsample = True
-            self.downsampler = jpeg_layers.Conv2d(spatial_resblock.downsampler, J_down)
+            self.downsampler = jpeg_layers.Conv2dRT(spatial_resblock.downsampler, J_down)
         else:
             self.downsample = False
@@ -34,4 +34,4 @@ class JpegResNet(nn.Module):
 class JpegResNetExact(JpegResNet):
     def __init__(self, spatial_model):
-        super(JpegResNet, self).__init__(spatial_model, 14)
+        super(JpegResNetExact, self).__init__(spatial_model, 14)
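The one-line fix in `JpegResNetExact` is more than a rename: `super(JpegResNet, self)` starts the MRO lookup *after* `JpegResNet`, so the old call dispatched to `nn.Module.__init__` with `(spatial_model, 14)` and skipped `JpegResNet.__init__` entirely, raising a `TypeError`. Passing the class being defined is the correct two-argument form. A toy illustration, unrelated to the model classes:

```python
class Base(object):
    def __init__(self, n):
        self.n = n


class Child(Base):
    def __init__(self):
        # Correct: name the class being defined; the MRO search starts at
        # its parent, so Base.__init__ runs and receives the argument.
        super(Child, self).__init__(14)


class BrokenChild(Base):
    def __init__(self):
        # Wrong: naming the parent skips Base and dispatches straight to
        # object.__init__, which rejects the extra argument.
        super(Base, self).__init__(14)


print(Child().n)     # 14
try:
    BrokenChild()
except TypeError as e:
    print('TypeError:', e)
```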
...@@ -3,15 +3,15 @@ from .blocks import SpatialResBlock ...@@ -3,15 +3,15 @@ from .blocks import SpatialResBlock
class SpatialResNet(nn.Module): class SpatialResNet(nn.Module):
def __init__(self, n_channels, n_classes): def __init__(self, channels, classes):
super(SpatialResNet, self).__init__() super(SpatialResNet, self).__init__()
self.block1 = SpatialResBlock(in_channels=n_channels, out_channels=16, downsample=False) self.block1 = SpatialResBlock(in_channels=channels, out_channels=16, downsample=False)
self.block2 = SpatialResBlock(in_channels=16, out_channels=32) self.block2 = SpatialResBlock(in_channels=16, out_channels=32)
self.block3 = SpatialResBlock(in_channels=32, out_channels=64) self.block3 = SpatialResBlock(in_channels=32, out_channels=64)
self.averagepooling = nn.AvgPool2d(8, stride=1) self.averagepooling = nn.AvgPool2d(8, stride=1)
self.fc = nn.Linear(64, n_classes) self.fc = nn.Linear(64, classes)
def forward(self, x): def forward(self, x):
out = self.block1(x) out = self.block1(x)
......
@@ -32,4 +32,6 @@ def test(model, device, test_loader):
     test_loss /= len(test_loader.dataset)
     print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
         test_loss, correct, len(test_loader.dataset),
-        100. * correct / len(test_loader.dataset)))
\ No newline at end of file
+        100. * correct / len(test_loader.dataset)))
+
+    return correct / len(test_loader.dataset)
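With `test()` now returning the accuracy instead of only printing it, the training loop can react to the score. A hedged sketch of one way the return value might be consumed; the `train` helper and the checkpoint path are assumptions, not part of this commit:

```python
import torch


def fit(model, device, train_loader, test_loader, optimizer, epochs):
    """Train for `epochs` epochs, keeping the checkpoint with the best test accuracy."""
    best_accuracy = 0.0
    for epoch in range(1, epochs + 1):
        train(model, device, train_loader, optimizer, epoch)  # assumed training helper
        accuracy = test(model, device, test_loader)           # uses the new return value
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            torch.save(model.state_dict(), 'best_model.pt')   # illustrative path
    return best_accuracy
```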