Don't retain the graph; it wastes memory

parent a99367a8
@@ -9,7 +9,7 @@ def train(model, device, train_loader, optimizer, epoch):
         optimizer.zero_grad()
         output = model(data)
         loss = F.cross_entropy(output, target)
-        loss.backward(retain_graph=True)
+        loss.backward()
         optimizer.step()
         if batch_idx % 100 == 0:
             print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
......
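For context (this sketch is not part of the commit): retain_graph=True tells autograd to keep the graph's intermediate buffers alive after backward(), which is only needed when you call backward() on the same graph again. A training loop like the one above builds a fresh graph every iteration, so retaining the previous graph only holds onto activation memory for no benefit. A minimal illustration of when the flag actually matters:

    import torch

    x = torch.randn(4, 4, requires_grad=True)
    y = (x * 2).sum()

    # retain_graph=True keeps the graph alive so backward() can run again;
    # without it, the second call below would raise a RuntimeError.
    y.backward(retain_graph=True)
    y.backward()  # works only because the graph was retained above

    # In a standard training loop, each iteration's forward pass builds a
    # new graph, so plain loss.backward() is correct and frees the buffers
    # as soon as the backward pass finishes.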