Verified Commit 20c01fcc authored by Max Ehrlich's avatar Max Ehrlich
Browse files

Allow variable number of classes

parent bb306f3d
...@@ -3,7 +3,7 @@ from .blocks import SpatialResBlock ...@@ -3,7 +3,7 @@ from .blocks import SpatialResBlock
class SpatialResNet(nn.Module): class SpatialResNet(nn.Module):
def __init__(self): def __init__(self, n_classes):
super(SpatialResNet, self).__init__() super(SpatialResNet, self).__init__()
self.block1 = SpatialResBlock(in_channels=1, out_channels=16, downsample=False) self.block1 = SpatialResBlock(in_channels=1, out_channels=16, downsample=False)
...@@ -11,7 +11,7 @@ class SpatialResNet(nn.Module): ...@@ -11,7 +11,7 @@ class SpatialResNet(nn.Module):
self.block3 = SpatialResBlock(in_channels=32, out_channels=64) self.block3 = SpatialResBlock(in_channels=32, out_channels=64)
self.averagepooling = nn.AvgPool2d(8, stride=1) self.averagepooling = nn.AvgPool2d(8, stride=1)
self.fc = nn.Linear(64, 10) self.fc = nn.Linear(64, n_classes)
def forward(self, x): def forward(self, x):
out = self.block1(x) out = self.block1(x)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment