Verified Commit f70d43ed authored by Max Ehrlich's avatar Max Ehrlich
Browse files

Fix dataset paths

parent 3e432454
...@@ -20,7 +20,8 @@ def mnist_spatial(batch_size, root, shuffle_train=True): ...@@ -20,7 +20,8 @@ def mnist_spatial(batch_size, root, shuffle_train=True):
transforms.Normalize((0.1307,), (0.3081,)) transforms.Normalize((0.1307,), (0.3081,))
]) ])
return spatial_data(batch_size=batch_size, root=root, name='MNIST-spatial', dataset=datasets.MNIST, transform=transform, shuffle_train=shuffle_train) directory = '{}/{}'.format(root, 'MNIST')
return spatial_data(batch_size=batch_size, root=directory, name='MNIST-spatial', dataset=datasets.MNIST, transform=transform, shuffle_train=shuffle_train)
def cifar10_spatial(batch_size, root, shuffle_train=True): def cifar10_spatial(batch_size, root, shuffle_train=True):
...@@ -29,7 +30,8 @@ def cifar10_spatial(batch_size, root, shuffle_train=True): ...@@ -29,7 +30,8 @@ def cifar10_spatial(batch_size, root, shuffle_train=True):
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)) transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
]) ])
return spatial_data(batch_size=batch_size, root=root, name='CIFAR10-spatial', dataset=datasets.CIFAR10, transform=transform, shuffle_train=shuffle_train) directory = '{}/{}'.format(root, 'CIFAR10')
return spatial_data(batch_size=batch_size, root=directory, name='CIFAR10-spatial', dataset=datasets.CIFAR10, transform=transform, shuffle_train=shuffle_train)
def cifar100_spatial(batch_size, root, shuffle_train=True): def cifar100_spatial(batch_size, root, shuffle_train=True):
...@@ -38,7 +40,8 @@ def cifar100_spatial(batch_size, root, shuffle_train=True): ...@@ -38,7 +40,8 @@ def cifar100_spatial(batch_size, root, shuffle_train=True):
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
]) ])
return spatial_data(batch_size=batch_size, root=root, name='CIFAR100-spatial', dataset=datasets.CIFAR100, transform=transform, shuffle_train=shuffle_train) directory = '{}/{}'.format(root, 'CIFAR100')
return spatial_data(batch_size=batch_size, root=directory, name='CIFAR100-spatial', dataset=datasets.CIFAR100, transform=transform, shuffle_train=shuffle_train)
def jpeg_data(batch_size, directory, shuffle_train): def jpeg_data(batch_size, directory, shuffle_train):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment