spatial_model.py 730 Bytes
Newer Older
Max Ehrlich's avatar
Max Ehrlich committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch.nn as nn
from .blocks import SpatialResBlock


class SpatialResNet(nn.Module):
    """Small three-stage residual classifier built from ``SpatialResBlock``s.

    Layers, in order:

    * ``block1``: ``SpatialResBlock`` ``in_channels -> 16`` with ``downsample=False``
    * ``block2``: ``SpatialResBlock`` ``16 -> 32`` (block's default ``downsample``)
    * ``block3``: ``SpatialResBlock`` ``32 -> 64`` (block's default ``downsample``)
    * global average pooling over the whole remaining spatial extent
    * ``fc``: linear layer ``64 -> num_classes``

    Args:
        in_channels: Number of channels of the input image. Defaults to 1,
            the value that was previously hard-coded.
        num_classes: Number of output scores per sample. Defaults to 10,
            the value that was previously hard-coded.

    NOTE(review): presumably ``SpatialResBlock`` halves the spatial size when
    its ``downsample`` flag is on (e.g. a 32x32 input would reach ``block3``'s
    output as 8x8) — confirm against ``blocks.py``.
    """

    def __init__(self, in_channels=1, num_classes=10):
        super(SpatialResNet, self).__init__()

        self.block1 = SpatialResBlock(in_channels=in_channels, out_channels=16, downsample=False)
        self.block2 = SpatialResBlock(in_channels=16, out_channels=32)
        self.block3 = SpatialResBlock(in_channels=32, out_channels=64)

        # Global average pool to 1x1 so ``fc`` always sees exactly 64 features.
        # The previous ``nn.AvgPool2d(8, stride=1)`` only produced 1x1 (and thus
        # only worked) when block3's output was exactly 8x8; for that case this
        # layer computes the identical average, and other sizes no longer crash.
        self.averagepooling = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x):
        """Run the network on a batch.

        Args:
            x: Input batch; presumably shaped (N, in_channels, H, W) — TODO
                confirm against ``SpatialResBlock``'s expected input.

        Returns:
            Tensor of shape (N, num_classes) holding unnormalized scores
            (no softmax is applied here).
        """
        out = self.block1(x)
        out = self.block2(out)
        out = self.block3(out)

        out = self.averagepooling(out)
        # (N, 64, 1, 1) -> (N, 64); keeps the batch dimension.
        out = out.flatten(1)

        out = self.fc(out)

        return out