Allow exploded convolutions to be precomputed for faster testing

parent 42ad5fb3
......@@ -33,6 +33,7 @@ for m in range(args.models):
print('Convert model to JPEG')
jpeg_model = models.JpegResNetExact(spatial_model).to(device)
jpeg_model.explode_all()
print('Test JPEG model')
jpeg_dataset = data.jpeg_dataset_map[args.dataset](args.batch_size, args.data)
......
......@@ -3,9 +3,9 @@ import opt_einsum as oe
import numpy as np
class Conv2dBase(torch.nn.modules.Module):
class Conv2d(torch.nn.modules.Module):
def __init__(self, conv_spatial, J):
super(Conv2dBase, self).__init__()
super(Conv2d, self).__init__()
self.stride = conv_spatial.stride
self.weight = conv_spatial.weight
......@@ -19,6 +19,8 @@ class Conv2dBase(torch.nn.modules.Module):
self.make_apply_op()
self.jpeg_op = None
def make_apply_op(self):
input_shape = [0, self.weight.shape[1], *self.J_i.shape[0:3]]
jpeg_op_shape = [self.weight.shape[0], self.weight.shape[1], *self.J_i.shape[0:3], *self.J.shape[0:2]]
......@@ -27,7 +29,7 @@ class Conv2dBase(torch.nn.modules.Module):
self.apply_conv.evaluate_constants(backend='torch')
def _apply(self, fn):
s = super(Conv2dBase, self)._apply(fn)
s = super(Conv2d, self)._apply(fn)
s.make_apply_op()
return s
......@@ -41,20 +43,13 @@ class Conv2dBase(torch.nn.modules.Module):
return jpeg_op
class Conv2dRT(Conv2dBase):
def __init__(self, conv_spatial, J):
super(Conv2dRT, self).__init__(conv_spatial, J)
def forward(self, input):
jpeg_op = self.explode()
return self.apply_conv(jpeg_op, input, backend='torch')
class Conv2dPre(Conv2dBase):
def __init__(self, conv_spatial, J):
super(Conv2dPre, self).__init__(conv_spatial, J)
def explode_pre(self):
self.jpeg_op = self.explode()
def forward(self, input):
return self.apply_conv(self.jpeg_op, input, backend='torch')
if self.jpeg_op is not None:
jpeg_op = self.jpeg_op
else:
jpeg_op = self.explode()
return self.apply_conv(jpeg_op, input, backend='torch')
......@@ -48,8 +48,8 @@ class JpegResBlock(nn.Module):
J_down = (J_out[0], J_in[1])
self.conv1 = jpeg_layers.Conv2dRT(spatial_resblock.conv1, J_down)
self.conv2 = jpeg_layers.Conv2dRT(spatial_resblock.conv2, J_out)
self.conv1 = jpeg_layers.Conv2d(spatial_resblock.conv1, J_down)
self.conv2 = jpeg_layers.Conv2d(spatial_resblock.conv2, J_out)
self.bn1 = jpeg_layers.BatchNorm(spatial_resblock.bn1)
self.bn2 = jpeg_layers.BatchNorm(spatial_resblock.bn2)
......@@ -57,10 +57,17 @@ class JpegResBlock(nn.Module):
self.relu = jpeg_layers.ReLU(n_freqs=n_freqs)
if spatial_resblock.downsampler is not None:
self.downsampler = jpeg_layers.Conv2dRT(spatial_resblock.downsampler, J_down)
self.downsampler = jpeg_layers.Conv2d(spatial_resblock.downsampler, J_down)
else:
self.downsampler = None
# Precompute the exploded operator of every convolution owned by this block by
# calling explode_pre() on each of them.
# NOTE(review): conv1/conv2/downsampler are built as jpeg_layers.Conv2d, while
# the only explode_pre shown in this diff sits under the Conv2dPre heading
# (it stores self.explode() in self.jpeg_op) - confirm Conv2d still defines it.
def explode_all(self):
self.conv1.explode_pre()
self.conv2.explode_pre()
# The downsampler is optional; it is None when the spatial block has none.
if self.downsampler is not None:
self.downsampler.explode_pre()
def forward(self, x):
out = self.conv1(x)
out = self.bn1(out)
......
......@@ -19,6 +19,11 @@ class JpegResNet(nn.Module):
self.averagepooling = AvgPool()
self.fc = spatial_model.fc
# Forward the explode_all() request to each of the three residual blocks, in
# the same order the blocks are applied in forward().
# NOTE(review): presumably meant to be called once before evaluation (the test
# script calls it right after moving the model to the device) - confirm that
# nothing moves or retrains the model afterwards, since results are cached.
def explode_all(self):
self.block1.explode_all()
self.block2.explode_all()
self.block3.explode_all()
def forward(self, x):
out = self.block1(x)
out = self.block2(out)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment