Verified Commit e249d6e7 authored by Max Ehrlich's avatar Max Ehrlich
Browse files

Add relu accuracy experiment

parent df707d83
import models
import torch
import torch.optim as optim
import argparse
import data
device = torch.device('cuda')
parser = argparse.ArgumentParser()
parser.add_argument('--models', type=int, help='Number of models to use')
parser.add_argument('--epochs', type=int, help='Number of epochs to train for')
parser.add_argument('--dataset', choices=data.spatial_dataset_map.keys(), help='Dataset to use')
parser.add_argument('--batch_size', type=int, help='Batch size')
parser.add_argument('--data', help='Root folder for the dataset')
args = parser.parse_args()
spatial_accuracies = 0
asm_accuracies = torch.zeros(15)
apx_accuracies = torch.zeros(15)
for m in range(args.models):
print('Train spatial model {}/{}'.format(m+1, args.models))
spatial_dataset = data.spatial_dataset_map[args.dataset](args.batch_size, args.data)
dataset_info = data.dataset_info[args.dataset]
spatial_model = models.SpatialResNet(dataset_info['channels'], dataset_info['classes']).to(device)
optimizer = optim.Adam(spatial_model.parameters())
for e in range(args.epochs):
models.train(spatial_model, device, spatial_dataset[0], optimizer, e)
models.test(spatial_model, device, spatial_dataset[1])
acc = models.test(spatial_model, device, spatial_dataset[1])
spatial_accuracies += acc
for f in range(15):
print('Convert model to ASM JPEG with {} spatial frequencies'.format(f))
jpeg_model = models.JpegResNet(spatial_model, n_freqs=f).to(device)
jpeg_model.explode_all()
print('Test ASM JPEG model')
jpeg_dataset = data.jpeg_dataset_map[args.dataset](args.batch_size, args.data)
acc = models.test(jpeg_model, device, jpeg_dataset[1])
asm_accuracies[f] += acc
print('Convert model to APX JPEG with {} spatial frequencies'.format(f))
jpeg_model = models.JpegResNetApx(spatial_model, n_freqs=f).to(device)
jpeg_model.explode_all()
print('Test APX JPEG model')
acc = models.test(jpeg_model, device, jpeg_dataset[1])
apx_accuracies[f] += acc
spatial_accuracies /= args.models
asm_accuracies /= args.models
apx_accuracies /= args.models
with open('{}_relu_accuracy.csv'.format(args.dataset), 'w') as f:
f.write('Spatial, ASM, APX\n')
for i in range(15):
f.write('{}, {}, {}\n'.format(spatial_accuracies, asm_accuracies[i], apx_accuracies[i]))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment