Verified Commit 5ab3c802 authored by Max Ehrlich's avatar Max Ehrlich
Browse files

Fix label datatype

parent 5726be9f
......@@ -18,4 +18,4 @@ def load_jpeg(directory, fname):
data = dl['data']
labels = dl['labels']
return torch.Tensor(data), torch.Tensor(labels)
return torch.Tensor(data), torch.LongTensor(labels)
......@@ -25,8 +25,8 @@ def test(model, device, test_loader):
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.cross_entropy(output, target, reduction='sum').item() # sum up batch loss
pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
test_loss += F.cross_entropy(output, target, reduction='sum').item()
pred = output.max(1, keepdim=True)[1]
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment