Verified Commit 1009394f authored by Max Ehrlich's avatar Max Ehrlich
Browse files

Add dataset information

parent e29753b4
...@@ -3,21 +3,15 @@ from .datasets import * ...@@ -3,21 +3,15 @@ from .datasets import *
from .jpeg_readwrite import * from .jpeg_readwrite import *
from jpeg_codec import encode from jpeg_codec import encode
dataset_map = {
'MNIST': mnist_spatial,
'CIFAR10': cifar10_spatial,
'CIFAR100': cifar100_spatial
}
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, help='Number of images to convert at a time', default=60000) parser.add_argument('--batch_size', type=int, help='Number of images to convert at a time', default=60000)
parser.add_argument('dataset', choices=dataset_map.keys(), help='Dataset to convert') parser.add_argument('dataset', choices=spatial_dataset_map.keys(), help='Dataset to convert')
parser.add_argument('directory', help='directory to load and store datasets') parser.add_argument('directory', help='directory to load and store datasets')
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
train_data, test_data = dataset_map[args.dataset](args.batch_size, root=args.directory, shuffle_train=False) train_data, test_data = spatial_dataset_map[args.dataset](args.batch_size, root=args.directory, shuffle_train=False)
device = torch.device('cuda') device = torch.device('cuda')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment