Verified commit b19163d7, authored by Max Ehrlich
Browse files

Allow exploded convolutions to be precomputed for faster testing

parent 42ad5fb3
...@@ -33,6 +33,7 @@ for m in range(args.models): ...@@ -33,6 +33,7 @@ for m in range(args.models):
print('Convert model to JPEG') print('Convert model to JPEG')
jpeg_model = models.JpegResNetExact(spatial_model).to(device) jpeg_model = models.JpegResNetExact(spatial_model).to(device)
jpeg_model.explode_all()
print('Test JPEG model') print('Test JPEG model')
jpeg_dataset = data.jpeg_dataset_map[args.dataset](args.batch_size, args.data) jpeg_dataset = data.jpeg_dataset_map[args.dataset](args.batch_size, args.data)
......
...@@ -3,9 +3,9 @@ import opt_einsum as oe ...@@ -3,9 +3,9 @@ import opt_einsum as oe
import numpy as np import numpy as np
class Conv2dBase(torch.nn.modules.Module): class Conv2d(torch.nn.modules.Module):
def __init__(self, conv_spatial, J): def __init__(self, conv_spatial, J):
super(Conv2dBase, self).__init__() super(Conv2d, self).__init__()
self.stride = conv_spatial.stride self.stride = conv_spatial.stride
self.weight = conv_spatial.weight self.weight = conv_spatial.weight
...@@ -19,6 +19,8 @@ class Conv2dBase(torch.nn.modules.Module): ...@@ -19,6 +19,8 @@ class Conv2dBase(torch.nn.modules.Module):
self.make_apply_op() self.make_apply_op()
self.jpeg_op = None
def make_apply_op(self): def make_apply_op(self):
input_shape = [0, self.weight.shape[1], *self.J_i.shape[0:3]] input_shape = [0, self.weight.shape[1], *self.J_i.shape[0:3]]
jpeg_op_shape = [self.weight.shape[0], self.weight.shape[1], *self.J_i.shape[0:3], *self.J.shape[0:2]] jpeg_op_shape = [self.weight.shape[0], self.weight.shape[1], *self.J_i.shape[0:3], *self.J.shape[0:2]]
...@@ -27,7 +29,7 @@ class Conv2dBase(torch.nn.modules.Module): ...@@ -27,7 +29,7 @@ class Conv2dBase(torch.nn.modules.Module):
self.apply_conv.evaluate_constants(backend='torch') self.apply_conv.evaluate_constants(backend='torch')
def _apply(self, fn): def _apply(self, fn):
s = super(Conv2dBase, self)._apply(fn) s = super(Conv2d, self)._apply(fn)
s.make_apply_op() s.make_apply_op()
return s return s
...@@ -41,20 +43,13 @@ class Conv2dBase(torch.nn.modules.Module): ...@@ -41,20 +43,13 @@ class Conv2dBase(torch.nn.modules.Module):
return jpeg_op return jpeg_op
def explode_pre(self):
class Conv2dRT(Conv2dBase):
def __init__(self, conv_spatial, J):
super(Conv2dRT, self).__init__(conv_spatial, J)
def forward(self, input):
jpeg_op = self.explode()
return self.apply_conv(jpeg_op, input, backend='torch')
class Conv2dPre(Conv2dBase):
def __init__(self, conv_spatial, J):
super(Conv2dPre, self).__init__(conv_spatial, J)
self.jpeg_op = self.explode() self.jpeg_op = self.explode()
def forward(self, input): def forward(self, input):
return self.apply_conv(self.jpeg_op, input, backend='torch') if self.jpeg_op is not None:
jpeg_op = self.jpeg_op
else:
jpeg_op = self.explode()
return self.apply_conv(jpeg_op, input, backend='torch')
...@@ -48,8 +48,8 @@ class JpegResBlock(nn.Module): ...@@ -48,8 +48,8 @@ class JpegResBlock(nn.Module):
J_down = (J_out[0], J_in[1]) J_down = (J_out[0], J_in[1])
self.conv1 = jpeg_layers.Conv2dRT(spatial_resblock.conv1, J_down) self.conv1 = jpeg_layers.Conv2d(spatial_resblock.conv1, J_down)
self.conv2 = jpeg_layers.Conv2dRT(spatial_resblock.conv2, J_out) self.conv2 = jpeg_layers.Conv2d(spatial_resblock.conv2, J_out)
self.bn1 = jpeg_layers.BatchNorm(spatial_resblock.bn1) self.bn1 = jpeg_layers.BatchNorm(spatial_resblock.bn1)
self.bn2 = jpeg_layers.BatchNorm(spatial_resblock.bn2) self.bn2 = jpeg_layers.BatchNorm(spatial_resblock.bn2)
...@@ -57,10 +57,17 @@ class JpegResBlock(nn.Module): ...@@ -57,10 +57,17 @@ class JpegResBlock(nn.Module):
self.relu = jpeg_layers.ReLU(n_freqs=n_freqs) self.relu = jpeg_layers.ReLU(n_freqs=n_freqs)
if spatial_resblock.downsampler is not None: if spatial_resblock.downsampler is not None:
self.downsampler = jpeg_layers.Conv2dRT(spatial_resblock.downsampler, J_down) self.downsampler = jpeg_layers.Conv2d(spatial_resblock.downsampler, J_down)
else: else:
self.downsampler = None self.downsampler = None
def explode_all(self):
self.conv1.explode_pre()
self.conv2.explode_pre()
if self.downsampler is not None:
self.downsampler.explode_pre()
def forward(self, x): def forward(self, x):
out = self.conv1(x) out = self.conv1(x)
out = self.bn1(out) out = self.bn1(out)
......
...@@ -19,6 +19,11 @@ class JpegResNet(nn.Module): ...@@ -19,6 +19,11 @@ class JpegResNet(nn.Module):
self.averagepooling = AvgPool() self.averagepooling = AvgPool()
self.fc = spatial_model.fc self.fc = spatial_model.fc
def explode_all(self):
self.block1.explode_all()
self.block2.explode_all()
self.block3.explode_all()
def forward(self, x): def forward(self, x):
out = self.block1(x) out = self.block1(x)
out = self.block2(out) out = self.block2(out)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment