# spatial_throughput.py
import models
import torch
import torch.optim as optim
import argparse
import data
import time

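# Benchmark on the GPU. cuDNN is turned off here, presumably so that its
# algorithm selection/autotuning does not influence the measured times.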
device = torch.device('cuda')
torch.backends.cudnn.enabled = False

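# Example invocation (values are illustrative; valid dataset names come from
# data.spatial_dataset_map):
#   python spatial_throughput.py --dataset <dataset> --batch_size 128 --data ./data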
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', required=True, choices=data.spatial_dataset_map.keys(), help='Dataset to use')
parser.add_argument('--batch_size', type=int, required=True, help='Batch size')
parser.add_argument('--data', required=True, help='Root folder for the dataset')
args = parser.parse_args()

spatial_dataset = data.spatial_dataset_map[args.dataset](args.batch_size, args.data)
dataset_info = data.dataset_info[args.dataset]
spatial_model = models.SpatialResNet(dataset_info['channels'], dataset_info['classes']).to(device)
optimizer = optim.Adam(spatial_model.parameters())


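# Time a single training epoch (epoch 0). torch.cuda.synchronize() ensures all
# queued CUDA kernels have finished before the clock is read.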
t0 = time.perf_counter()
models.train(spatial_model, device, spatial_dataset[0], optimizer, 0)
torch.cuda.synchronize()
t1 = time.perf_counter()

training_time = t1 - t0

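# Time a full pass over the test set, again synchronizing before reading the clock.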
t0 = time.perf_counter()
models.test(spatial_model, device, spatial_dataset[1])
torch.cuda.synchronize()
t1 = time.perf_counter()
testing_time = t1 - t0

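# Write the average seconds per batch (total time divided by the number of
# batches, assuming spatial_dataset[0]/[1] are the train/test data loaders).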
with open('{}_spatial_throughput.csv'.format(args.dataset), 'w') as f:
    f.write('Training, Testing\n')
    f.write('{}, {}\n'.format(training_time / len(spatial_dataset[0]), testing_time / len(spatial_dataset[1])))