Make sure to batch norm after downsampling convolution

parent b19163d7
......@@ -20,6 +20,7 @@ class SpatialResBlock(nn.Module):
if downsample or in_channels != out_channels:
stride = 2 if downsample else 1
self.downsampler = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, padding=0, bias=False)
self.bn_ds = nn.BatchNorm2d(out_channels)
else:
self.downsampler = None
......@@ -33,6 +34,7 @@ class SpatialResBlock(nn.Module):
if self.downsampler is not None:
residual = self.downsampler(x)
residual = self.bn_ds(residual)
else:
residual = x
......@@ -58,6 +60,7 @@ class JpegResBlock(nn.Module):
if spatial_resblock.downsampler is not None:
self.downsampler = jpeg_layers.Conv2d(spatial_resblock.downsampler, J_down)
self.bn_ds = jpeg_layers.BatchNorm(spatial_resblock.bn_ds)
else:
self.downsampler = None
......@@ -78,6 +81,7 @@ class JpegResBlock(nn.Module):
if self.downsampler is not None:
residual = self.downsampler(x)
residual = self.bn_ds(residual)
else:
residual = x
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment