Verified Commit b0a76158 authored by Max Ehrlich's avatar Max Ehrlich
Browse files

Make sure to batch norm after downsampling convolution

parent b19163d7
...@@ -20,6 +20,7 @@ class SpatialResBlock(nn.Module): ...@@ -20,6 +20,7 @@ class SpatialResBlock(nn.Module):
if downsample or in_channels != out_channels: if downsample or in_channels != out_channels:
stride = 2 if downsample else 1 stride = 2 if downsample else 1
self.downsampler = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, padding=0, bias=False) self.downsampler = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, padding=0, bias=False)
self.bn_ds = nn.BatchNorm2d(out_channels)
else: else:
self.downsampler = None self.downsampler = None
...@@ -33,6 +34,7 @@ class SpatialResBlock(nn.Module): ...@@ -33,6 +34,7 @@ class SpatialResBlock(nn.Module):
if self.downsampler is not None: if self.downsampler is not None:
residual = self.downsampler(x) residual = self.downsampler(x)
residual = self.bn_ds(residual)
else: else:
residual = x residual = x
...@@ -58,6 +60,7 @@ class JpegResBlock(nn.Module): ...@@ -58,6 +60,7 @@ class JpegResBlock(nn.Module):
if spatial_resblock.downsampler is not None: if spatial_resblock.downsampler is not None:
self.downsampler = jpeg_layers.Conv2d(spatial_resblock.downsampler, J_down) self.downsampler = jpeg_layers.Conv2d(spatial_resblock.downsampler, J_down)
self.bn_ds = jpeg_layers.BatchNorm(spatial_resblock.bn_ds)
else: else:
self.downsampler = None self.downsampler = None
...@@ -78,6 +81,7 @@ class JpegResBlock(nn.Module): ...@@ -78,6 +81,7 @@ class JpegResBlock(nn.Module):
if self.downsampler is not None: if self.downsampler is not None:
residual = self.downsampler(x) residual = self.downsampler(x)
residual = self.bn_ds(residual)
else: else:
residual = x residual = x
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment