Verified commit 3b8e93f7, authored by Max Ehrlich
Browse files

Add jpeg loading tools

parent ee122b95
import torch
from torchvision import datasets, transforms
from jpeg_dataset import JpegDataset
def make_data(batch_size, root, name, dataset, transform, shuffle_train):
def spatial_data(batch_size, root, name, dataset, transform, shuffle_train):
train_data = dataset('{}/{}-data'.format(root, name), train=True, download=True, transform=transform)
test_data = dataset('{}/{}-data'.format(root, name), train=False, download=True, transform=transform)
......@@ -19,7 +20,7 @@ def mnist_spatial(batch_size, root, shuffle_train=True):
transforms.Normalize((0.1307,), (0.3081,))
])
return make_data(batch_size=batch_size, root=root, name='MNIST-spatial', dataset=datasets.MNIST, transform=transform, shuffle_train=shuffle_train)
return spatial_data(batch_size=batch_size, root=root, name='MNIST-spatial', dataset=datasets.MNIST, transform=transform, shuffle_train=shuffle_train)
def cifar10_spatial(batch_size, root, shuffle_train=True):
......@@ -28,7 +29,7 @@ def cifar10_spatial(batch_size, root, shuffle_train=True):
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
])
return make_data(batch_size=batch_size, root=root, name='CIFAR10-spatial', dataset=datasets.CIFAR10, transform=transform, shuffle_train=shuffle_train)
return spatial_data(batch_size=batch_size, root=root, name='CIFAR10-spatial', dataset=datasets.CIFAR10, transform=transform, shuffle_train=shuffle_train)
def cifar100_spatial(batch_size, root, shuffle_train=True):
......@@ -37,4 +38,29 @@ def cifar100_spatial(batch_size, root, shuffle_train=True):
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])
return make_data(batch_size=batch_size, root=root, name='CIFAR100-spatial', dataset=datasets.CIFAR100, transform=transform, shuffle_train=shuffle_train)
return spatial_data(batch_size=batch_size, root=root, name='CIFAR100-spatial', dataset=datasets.CIFAR100, transform=transform, shuffle_train=shuffle_train)
def jpeg_data(batch_size, directory, shuffle_train):
    """Create train and test data loaders backed by JpegDataset.

    Two datasets are built from ``directory``: one from ``train.npz`` and one
    from ``test.npz``. Each is wrapped in a ``torch.utils.data.DataLoader``
    using the given ``batch_size``.

    Args:
        batch_size: Batch size passed to both loaders.
        directory: Directory handed to ``JpegDataset`` for both splits.
        shuffle_train: Whether the training loader shuffles its samples.
            The test loader never shuffles.

    Returns:
        A ``(train_loader, test_loader)`` tuple.
    """
    # Build both datasets first (train, then test) before any loader exists.
    train_set, test_set = [
        JpegDataset(directory=directory, fname=archive)
        for archive in ('train.npz', 'test.npz')
    ]

    make_loader = torch.utils.data.DataLoader
    return (
        make_loader(train_set, batch_size=batch_size, shuffle=shuffle_train),
        make_loader(test_set, batch_size=batch_size, shuffle=False),
    )
def mnist_jpeg(batch_size, root, shuffle_train=True):
    """Return ``(train_loader, test_loader)`` for the data under ``<root>/MNIST``.

    Args:
        batch_size: Batch size forwarded to ``jpeg_data``.
        root: Parent directory; ``/MNIST`` is appended to it.
        shuffle_train: Whether the training loader shuffles (default True).
    """
    return jpeg_data(
        batch_size=batch_size,
        directory='{}/{}'.format(root, 'MNIST'),
        shuffle_train=shuffle_train,
    )
def cifar10_jpeg(batch_size, root, shuffle_train=True):
    """Return ``(train_loader, test_loader)`` for the data under ``<root>/CIFAR10``.

    Args:
        batch_size: Batch size forwarded to ``jpeg_data``.
        root: Parent directory; ``/CIFAR10`` is appended to it.
        shuffle_train: Whether the training loader shuffles (default True).
    """
    return jpeg_data(
        batch_size=batch_size,
        directory='{}/{}'.format(root, 'CIFAR10'),
        shuffle_train=shuffle_train,
    )
def cifar100_jpeg(batch_size, root, shuffle_train=True):
    """Return ``(train_loader, test_loader)`` for the data under ``<root>/CIFAR100``.

    Args:
        batch_size: Batch size forwarded to ``jpeg_data``.
        root: Parent directory; ``/CIFAR100`` is appended to it.
        shuffle_train: Whether the training loader shuffles (default True).
    """
    return jpeg_data(
        batch_size=batch_size,
        directory='{}/{}'.format(root, 'CIFAR100'),
        shuffle_train=shuffle_train,
    )
import torch
from jpeg_readwrite import load_jpeg
class JpegDataset(torch.utils.data.Dataset):
    """Map-style dataset over a (data, labels) pair loaded by ``load_jpeg``.

    Indexing returns the sample and its label at the same position.
    """

    def __init__(self, directory, fname):
        """Load the data and labels via ``load_jpeg(directory, fname)``.

        Args:
            directory: Directory argument forwarded to ``load_jpeg``.
            fname: File name argument forwarded to ``load_jpeg``.
        """
        # NOTE(review): assumes load_jpeg returns array-likes whose first
        # axis indexes samples, with data and labels aligned -- confirm.
        self.data, self.labels = load_jpeg(directory, fname)

    def __len__(self):
        """Return the number of samples (size of the first axis of the data)."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair stored at position ``idx``."""
        # The samples live in ``self.data``; the previous ``self.dataset``
        # attribute was never assigned, so every lookup raised AttributeError.
        return self.data[idx], self.labels[idx]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment