Verified Commit 42ad5fb3 authored by Max Ehrlich's avatar Max Ehrlich
Browse files

Correct handling of blocks with different input and output channels

parent f70d43ed
......@@ -17,8 +17,11 @@ class SpatialResBlock(nn.Module):
self.bn2 = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
if downsample:
self.downsampler = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2, padding=0, bias=False)
if downsample or in_channels != out_channels:
stride = 2 if downsample else 1
self.downsampler = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, padding=0, bias=False)
else:
self.downsampler = None
def forward(self, x):
out = self.conv1(x)
......@@ -28,7 +31,7 @@ class SpatialResBlock(nn.Module):
out = self.conv2(out)
out = self.bn2(out)
if self.downsample:
if self.downsampler is not None:
residual = self.downsampler(x)
else:
residual = x
......@@ -53,11 +56,10 @@ class JpegResBlock(nn.Module):
self.relu = jpeg_layers.ReLU(n_freqs=n_freqs)
if spatial_resblock.downsample:
self.downsample = True
if spatial_resblock.downsampler is not None:
self.downsampler = jpeg_layers.Conv2dRT(spatial_resblock.downsampler, J_down)
else:
self.downsample = False
self.downsampler = None
def forward(self, x):
out = self.conv1(x)
......@@ -67,7 +69,7 @@ class JpegResBlock(nn.Module):
out = self.conv2(out)
out = self.bn2(out)
if self.downsample:
if self.downsampler is not None:
residual = self.downsampler(x)
else:
residual = x
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment