Verified Commit 1009394f authored by Max Ehrlich's avatar Max Ehrlich
Browse files

Add dataset information

parent e29753b4
......@@ -3,21 +3,15 @@ from .datasets import *
from .jpeg_readwrite import *
from jpeg_codec import encode
dataset_map = {
'MNIST': mnist_spatial,
'CIFAR10': cifar10_spatial,
'CIFAR100': cifar100_spatial
parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, help='Number of images to convert at a time', default=60000)
parser.add_argument('dataset', choices=dataset_map.keys(), help='Dataset to convert')
parser.add_argument('dataset', choices=spatial_dataset_map.keys(), help='Dataset to convert')
parser.add_argument('directory', help='directory to load and store datasets')
args = parser.parse_args()
train_data, test_data = dataset_map[args.dataset](args.batch_size,, shuffle_train=False)
train_data, test_data = spatial_dataset_map[args.dataset](args.batch_size,, shuffle_train=False)
device = torch.device('cuda')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment