import torch
import torch.nn as nn

from .blocks import SpatialResBlock


class SpatialResNet(nn.Module):
    """Three-stage residual classifier built from ``SpatialResBlock``.

    The network stacks three ``SpatialResBlock`` stages (16 -> 32 -> 64
    channels). It averages each of the 64 final feature maps down to a single
    value, then maps the resulting 64-dim vector to class logits with a
    linear layer.

    Args:
        n_channels: Number of channels in the input tensor (``in_channels``
            of the first block).
        n_classes: Number of outputs of the final linear layer.
    """

    def __init__(self, n_channels, n_classes):
        super().__init__()
        # block1 explicitly disables downsampling; block2/block3 rely on
        # SpatialResBlock's default ``downsample`` value (presumably True,
        # i.e. they reduce H/W -- TODO confirm in .blocks).
        self.block1 = SpatialResBlock(
            in_channels=n_channels, out_channels=16, downsample=False
        )
        self.block2 = SpatialResBlock(in_channels=16, out_channels=32)
        self.block3 = SpatialResBlock(in_channels=32, out_channels=64)
        # Global average pooling to a 1x1 map, whatever block3's spatial size.
        # A fixed AvgPool2d(8, stride=1) only reaches 1x1 for an exactly 8x8
        # map (any other size breaks the 64-wide ``fc`` below). For an 8x8
        # map the two are numerically identical.
        self.averagepooling = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(64, n_classes)

    def forward(self, x):
        """Return class logits for a batch ``x``.

        ``x`` is presumably shaped (N, n_channels, H, W); dim 0 is treated as
        the batch dimension. Returns a tensor of shape (N, n_classes).
        """
        out = self.block1(x)
        out = self.block2(out)
        out = self.block3(out)
        out = self.averagepooling(out)  # -> (N, 64, 1, 1)
        # flatten (unlike view) also copes with non-contiguous inputs.
        out = torch.flatten(out, 1)  # -> (N, 64)
        out = self.fc(out)
        return out