Verified Commit c7ce7cb8 authored by Max Ehrlich's avatar Max Ehrlich
Browse files

Add model definitions

parent c76ec438
import torch.nn as nn
import jpeg_layers
class SpatialResBlock(nn.Module):
def __init__(self, in_channels, out_channels, downsample=True):
super(SpatialResBlock, self).__init__()
self.downsample = downsample
stride = 2 if downsample else 1
self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(out_channels)
self.bn2 = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
if downsample:
self.downsampler = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2, padding=0, bias=False)
def forward(self, x):
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample:
residual = self.downsampler(x)
else:
residual = x
out += residual
out = self.relu(out)
return out
class JpegResBlock(nn.Module):
def __init__(self, spatial_resblock, n_freqs, J_in, J_out):
super(JpegResBlock, self).__init__()
J_down = (J_out[0], J_in[1])
self.conv1 = jpeg_layers.Conv2d_a(spatial_resblock.conv1, J_down)
self.conv2 = jpeg_layers.Conv2d_a(spatial_resblock.conv2, J_out)
self.bn1 = jpeg_layers.BatchNorm(spatial_resblock.bn1)
self.bn2 = jpeg_layers.BatchNorm(spatial_resblock.bn2)
self.relu = jpeg_layers.ReLU(n_freqs=n_freqs)
if spatial_resblock.downsample:
self.downsample = True
self.downsampler = jpeg_layers.Conv2d_a(spatial_resblock.downsampler, J_down)
else:
self.downsample = False
def forward(self, x):
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample:
residual = self.downsampler(x)
else:
residual = x
out += residual
out = self.relu(out)
return out
import torch.nn as nn
from .blocks import JpegResBlock
from jpeg_layers import AvgPool
from jpeg_codec import codec
class JpegResNet(nn.Module):
def __init__(self, spatial_model, n_freqs):
super(JpegResNet, self).__init__()
J_32 = codec((32, 32))
J_16 = codec((16, 16))
J_8 = codec((8, 8))
self.block1 = JpegResBlock(spatial_model.block1, n_freqs=n_freqs, J_in=J_32, J_out=J_32)
self.block2 = JpegResBlock(spatial_model.block2, n_freqs=n_freqs, J_in=J_32, J_out=J_16)
self.block3 = JpegResBlock(spatial_model.block3, n_freqs=n_freqs, J_in=J_16, J_out=J_8)
self.averagepooling = AvgPool()
self.fc = spatial_model.fc
def forward(self, x):
out = self.block1(x)
out = self.block2(out)
out = self.block3(out)
out = self.averagepooling(out)
out = out.view(x.size(0), -1)
out = self.fc(out)
return out
class JpegResNetExact(JpegResNet):
def __init__(self, spatial_model):
super(JpegResNet, self).__init__(spatial_model, 14)
import torch.nn as nn
from .blocks import SpatialResBlock
class SpatialResNet(nn.Module):
def __init__(self):
super(SpatialResNet, self).__init__()
self.block1 = SpatialResBlock(in_channels=1, out_channels=16, downsample=False)
self.block2 = SpatialResBlock(in_channels=16, out_channels=32)
self.block3 = SpatialResBlock(in_channels=32, out_channels=64)
self.averagepooling = nn.AvgPool2d(8, stride=1)
self.fc = nn.Linear(64, 10)
def forward(self, x):
out = self.block1(x)
out = self.block2(out)
out = self.block3(out)
out = self.averagepooling(out)
out = out.view(x.size(0), -1)
out = self.fc(out)
return out
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment