Verified Commit ce5115fc authored by Max Ehrlich's avatar Max Ehrlich
Browse files

Fix naming convention

parent 9dd38e1e
......@@ -3,9 +3,9 @@ import opt_einsum as oe
import numpy as np
class Conv2d_base(torch.nn.modules.Module):
class Conv2dBase(torch.nn.modules.Module):
def __init__(self, conv_spatial, J):
super(Conv2d_base, self).__init__()
super(Conv2dBase, self).__init__()
self.stride = conv_spatial.stride
self.weight = conv_spatial.weight
......@@ -31,18 +31,18 @@ class Conv2d_base(torch.nn.modules.Module):
return jpeg_op
class Conv2d_a(Conv2d_base):
class Conv2d(Conv2dBase):
def __init__(self, conv_spatial, J):
super(Conv2d_a, self).__init__(conv_spatial, J)
super(Conv2d, self).__init__(conv_spatial, J)
def forward(self, input):
jpeg_op = self.explode()
return self.apply_conv(jpeg_op, input, backend='torch')
class Conv2d_b(Conv2d_base):
class Conv2dPre(Conv2dBase):
def __init__(self, conv_spatial, J):
super(Conv2d_b, self).__init__(conv_spatial, J)
super(Conv2dPre, self).__init__(conv_spatial, J)
self.jpeg_op = self.explode()
def forward(self, input):
......
......@@ -45,8 +45,8 @@ class JpegResBlock(nn.Module):
J_down = (J_out[0], J_in[1])
self.conv1 = jpeg_layers.Conv2d_a(spatial_resblock.conv1, J_down)
self.conv2 = jpeg_layers.Conv2d_a(spatial_resblock.conv2, J_out)
self.conv1 = jpeg_layers.Conv2d(spatial_resblock.conv1, J_down)
self.conv2 = jpeg_layers.Conv2d(spatial_resblock.conv2, J_out)
self.bn1 = jpeg_layers.BatchNorm(spatial_resblock.bn1)
self.bn2 = jpeg_layers.BatchNorm(spatial_resblock.bn2)
......@@ -55,7 +55,7 @@ class JpegResBlock(nn.Module):
if spatial_resblock.downsample:
self.downsample = True
self.downsampler = jpeg_layers.Conv2d_a(spatial_resblock.downsampler, J_down)
self.downsampler = jpeg_layers.Conv2d(spatial_resblock.downsampler, J_down)
else:
self.downsample = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment