Verified Commit 84dd7d06 authored by Max Ehrlich's avatar Max Ehrlich
Browse files

Add model conversion experiment code

parent 1009394f
import models
import torch
import torch.optim as optim
import argparse
import data
import models
device = torch.device('cuda')
parser = argparse.ArgumentParser()
parser.add_argument('--models', type=int, help='Number of models to use')
parser.add_argument('--epochs', type=int, help='Number of epochs to train for')
parser.add_argument('--dataset', choices=data.spatial_dataset_map.keys(), help='Dataset to use')
parser.add_argument('--batch_size', type=int, help='Batch size')
parser.add_argument('--data', help='Root folder for the dataset')
args = parser.parse_args()
spatial_accuracies = 0
jpeg_accuracies = 0
for m in range(args.models):
print('Train spatial model {}/{}'.format(m, args.models))
spatial_dataset = data.spatial_dataset_map[args.dataset](args.batch_size, args.data)
dataset_info = data.dataset_info[args.dataset]
spatial_model = models.SpatialResNet(dataset_info['channels'], dataset_info['classes']).to(device)
optimizer = optim.Adam(spatial_model.parameters())
for e in range(args.epochs):
models.train(spatial_model, device, spatial_dataset[0], optimizer, e)
models.test(spatial_model, device, spatial_dataset[1])
acc = models.test(spatial_model, device, spatial_dataset[1])
spatial_accuracies += acc
print('Convert model to JPEG')
jpeg_model = models.JpegResNetExact(spatial_model).to(device)
print('Test JPEG model')
jpeg_dataset = data.jpeg_dataset_map[args.dataset](args.batch_size, args.data)
acc = models.test(jpeg_model, device, jpeg_dataset[1])
jpeg_accuracies += acc
spatial_accuracies /= args.models
jpeg_accuracies /= args.models
print('Report')
print('======')
print('Number of Models: {}'.format(args.n_models))
print('Average Spatial Accuracy: {}'.format(spatial_accuracies))
print('Average JPEG Accuracy: {}'.format(jpeg_accuracies))
print('Deviation: {}'.format(spatial_accuracies - jpeg_accuracies))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment