Verified Commit 20c01fcc authored by Max Ehrlich's avatar Max Ehrlich
Browse files

Allow variable number of classes

parent bb306f3d
......@@ -3,7 +3,7 @@ from .blocks import SpatialResBlock
class SpatialResNet(nn.Module):
def __init__(self):
def __init__(self, n_classes):
super(SpatialResNet, self).__init__()
self.block1 = SpatialResBlock(in_channels=1, out_channels=16, downsample=False)
......@@ -11,7 +11,7 @@ class SpatialResNet(nn.Module):
self.block3 = SpatialResBlock(in_channels=32, out_channels=64)
self.averagepooling = nn.AvgPool2d(8, stride=1)
self.fc = nn.Linear(64, 10)
self.fc = nn.Linear(64, n_classes)
def forward(self, x):
out = self.block1(x)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment