Verified Commit 65a729a5 authored by Max Ehrlich's avatar Max Ehrlich
Browse files

Simple training experiment

parent cc5094a1
import models
import torch
import torch.optim as optim
import argparse
import data
device = torch.device('cuda')
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', choices=data.spatial_dataset_map.keys(), help='Dataset to use')
parser.add_argument('--data', help='Root folder for the dataset')
args = parser.parse_args()
dataset_info = data.dataset_info[args.dataset]
spatial_model = models.SpatialResNet(dataset_info['channels'], dataset_info['classes'])
jpeg_model = models.JpegResNetExact(spatial_model).to(device)
train_set, test_set = data.jpeg_dataset_map[args.dataset](128,
optimizer = optim.Adam(jpeg_model.parameters())
for e in range(5):
models.train(jpeg_model, device, train_set, optimizer, e)
models.test(jpeg_model, device, test_set)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment