Verified Commit 3b8e93f7 authored by Max Ehrlich's avatar Max Ehrlich
Browse files

Add jpeg loading tools

parent ee122b95
import torch import torch
from torchvision import datasets, transforms from torchvision import datasets, transforms
from jpeg_dataset import JpegDataset
def make_data(batch_size, root, name, dataset, transform, shuffle_train): def spatial_data(batch_size, root, name, dataset, transform, shuffle_train):
train_data = dataset('{}/{}-data'.format(root, name), train=True, download=True, transform=transform) train_data = dataset('{}/{}-data'.format(root, name), train=True, download=True, transform=transform)
test_data = dataset('{}/{}-data'.format(root, name), train=False, download=True, transform=transform) test_data = dataset('{}/{}-data'.format(root, name), train=False, download=True, transform=transform)
...@@ -19,7 +20,7 @@ def mnist_spatial(batch_size, root, shuffle_train=True): ...@@ -19,7 +20,7 @@ def mnist_spatial(batch_size, root, shuffle_train=True):
transforms.Normalize((0.1307,), (0.3081,)) transforms.Normalize((0.1307,), (0.3081,))
]) ])
return make_data(batch_size=batch_size, root=root, name='MNIST-spatial', dataset=datasets.MNIST, transform=transform, shuffle_train=shuffle_train) return spatial_data(batch_size=batch_size, root=root, name='MNIST-spatial', dataset=datasets.MNIST, transform=transform, shuffle_train=shuffle_train)
def cifar10_spatial(batch_size, root, shuffle_train=True): def cifar10_spatial(batch_size, root, shuffle_train=True):
...@@ -28,7 +29,7 @@ def cifar10_spatial(batch_size, root, shuffle_train=True): ...@@ -28,7 +29,7 @@ def cifar10_spatial(batch_size, root, shuffle_train=True):
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)) transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
]) ])
return make_data(batch_size=batch_size, root=root, name='CIFAR10-spatial', dataset=datasets.CIFAR10, transform=transform, shuffle_train=shuffle_train) return spatial_data(batch_size=batch_size, root=root, name='CIFAR10-spatial', dataset=datasets.CIFAR10, transform=transform, shuffle_train=shuffle_train)
def cifar100_spatial(batch_size, root, shuffle_train=True): def cifar100_spatial(batch_size, root, shuffle_train=True):
...@@ -37,4 +38,29 @@ def cifar100_spatial(batch_size, root, shuffle_train=True): ...@@ -37,4 +38,29 @@ def cifar100_spatial(batch_size, root, shuffle_train=True):
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
]) ])
return make_data(batch_size=batch_size, root=root, name='CIFAR100-spatial', dataset=datasets.CIFAR100, transform=transform, shuffle_train=shuffle_train) return spatial_data(batch_size=batch_size, root=root, name='CIFAR100-spatial', dataset=datasets.CIFAR100, transform=transform, shuffle_train=shuffle_train)
def jpeg_data(batch_size, directory, shuffle_train):
train_data = JpegDataset(directory=directory, fname='train.npz')
test_data = JpegDataset(directory=directory, fname='test.npz')
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=shuffle_train)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False)
return train_loader, test_loader
def mnist_jpeg(batch_size, root, shuffle_train=True):
directory = '{}/{}'.format(root, 'MNIST')
return jpeg_data(batch_size, directory, shuffle_train)
def cifar10_jpeg(batch_size, root, shuffle_train=True):
directory = '{}/{}'.format(root, 'CIFAR10')
return jpeg_data(batch_size, directory, shuffle_train)
def cifar100_jpeg(batch_size, root, shuffle_train=True):
directory = '{}/{}'.format(root, 'CIFAR100')
return jpeg_data(batch_size, directory, shuffle_train)
import torch
from jpeg_readwrite import load_jpeg
class JpegDataset(torch.utils.data.Dataset):
def __init__(self, directory, fname):
self.data, self.labels = load_jpeg(directory, fname)
def __len__(self):
return self.data.shape[0]
def __getitem__(self, idx):
return self.dataset[idx], self.labels[idx]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment