Verified Commit 42ad5fb3 authored by Max Ehrlich's avatar Max Ehrlich
Browse files

Correct handling of blocks with different input and output channels

parent f70d43ed
...@@ -17,8 +17,11 @@ class SpatialResBlock(nn.Module): ...@@ -17,8 +17,11 @@ class SpatialResBlock(nn.Module):
self.bn2 = nn.BatchNorm2d(out_channels) self.bn2 = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True) self.relu = nn.ReLU(inplace=True)
if downsample: if downsample or in_channels != out_channels:
self.downsampler = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2, padding=0, bias=False) stride = 2 if downsample else 1
self.downsampler = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, padding=0, bias=False)
else:
self.downsampler = None
def forward(self, x): def forward(self, x):
out = self.conv1(x) out = self.conv1(x)
...@@ -28,7 +31,7 @@ class SpatialResBlock(nn.Module): ...@@ -28,7 +31,7 @@ class SpatialResBlock(nn.Module):
out = self.conv2(out) out = self.conv2(out)
out = self.bn2(out) out = self.bn2(out)
if self.downsample: if self.downsampler is not None:
residual = self.downsampler(x) residual = self.downsampler(x)
else: else:
residual = x residual = x
...@@ -53,11 +56,10 @@ class JpegResBlock(nn.Module): ...@@ -53,11 +56,10 @@ class JpegResBlock(nn.Module):
self.relu = jpeg_layers.ReLU(n_freqs=n_freqs) self.relu = jpeg_layers.ReLU(n_freqs=n_freqs)
if spatial_resblock.downsample: if spatial_resblock.downsampler is not None:
self.downsample = True
self.downsampler = jpeg_layers.Conv2dRT(spatial_resblock.downsampler, J_down) self.downsampler = jpeg_layers.Conv2dRT(spatial_resblock.downsampler, J_down)
else: else:
self.downsample = False self.downsampler = None
def forward(self, x): def forward(self, x):
out = self.conv1(x) out = self.conv1(x)
...@@ -67,7 +69,7 @@ class JpegResBlock(nn.Module): ...@@ -67,7 +69,7 @@ class JpegResBlock(nn.Module):
out = self.conv2(out) out = self.conv2(out)
out = self.bn2(out) out = self.bn2(out)
if self.downsample: if self.downsampler is not None:
residual = self.downsampler(x) residual = self.downsampler(x)
else: else:
residual = x residual = x
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment