Verified Commit d4208bd3 authored by Max Ehrlich's avatar Max Ehrlich
Browse files

Add relu training experiment

parent 5b29b727
import models
import torch
import torch.optim as optim
import argparse
import data
device = torch.device('cuda')
parser = argparse.ArgumentParser()
parser.add_argument('--models', type=int, help='Number of models to use')
parser.add_argument('--epochs', type=int, help='Number of epochs to train for')
parser.add_argument('--dataset', choices=data.spatial_dataset_map.keys(), help='Dataset to use')
parser.add_argument('--batch_size', type=int, help='Batch size')
parser.add_argument('--data', help='Root folder for the dataset')
args = parser.parse_args()
spatial_accuracies = 0
asm_accuracies = torch.zeros(15)
apx_accuracies = torch.zeros(15)
for m in range(args.models):
print('Train spatial model {}/{}'.format(m+1, args.models))
spatial_dataset = data.spatial_dataset_map[args.dataset](args.batch_size, args.data)
dataset_info = data.dataset_info[args.dataset]
spatial_model = models.SpatialResNet(dataset_info['channels'], dataset_info['classes']).to(device)
optimizer = optim.Adam(spatial_model.parameters())
for e in range(args.epochs):
models.train(spatial_model, device, spatial_dataset[0], optimizer, e)
acc = models.test(spatial_model, device, spatial_dataset[1])
spatial_accuracies += acc
jpeg_dataset = data.jpeg_dataset_map[args.dataset](args.batch_size, args.data)
for f in range(15):
print('Train ASM JPEG with {} spatial frequencies'.format(f))
jpeg_model = models.JpegResNet(models.SpatialResNet(dataset_info['channels'], dataset_info['classes']), n_freqs=f).to(device)
optimizer = optim.Adam(jpeg_model.parameters())
for e in range(args.epochs):
models.train(jpeg_model, device, jpeg_dataset[0], optimizer, e)
acc = models.test(jpeg_model, device, jpeg_dataset[1])
asm_accuracies[f] += acc
print('Train APX JPEG with {} spatial frequencies'.format(f))
jpeg_model = models.JpegResNetApx(models.SpatialResNet(dataset_info['channels'], dataset_info['classes']), n_freqs=f).to(device)
optimizer = optim.Adam(jpeg_model.parameters())
for e in range(args.epochs):
models.train(jpeg_model, device, jpeg_dataset[0], optimizer, e)
acc = models.test(jpeg_model, device, jpeg_dataset[1])
apx_accuracies[f] += acc
spatial_accuracies /= args.models
asm_accuracies /= args.models
apx_accuracies /= args.models
with open('{}_relu_training.csv'.format(args.dataset), 'w') as f:
f.write('Spatial, ASM, APX\n')
for i in range(15):
f.write('{}, {}, {}\n'.format(spatial_accuracies, asm_accuracies[i], apx_accuracies[i]))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment