Verified Commit cbde3c94 authored by Max Ehrlich's avatar Max Ehrlich
Browse files

Add dataset conversion and loading tools

parent b59d8c91
import torch
from torchvision import datasets, transforms
def make_data(batch_size, root, name, dataset, transform, shuffle_train):
train_data = dataset('{}/{}-data'.format(root, name), train=True, download=True, transform=transform)
test_data = dataset('{}/{}-data'.format(root, name), train=False, download=True, transform=transform)
train_loader =, batch_size=batch_size, shuffle=shuffle_train)
test_loader =, batch_size=batch_size, shuffle=False)
return train_loader, test_loader
def mnist_spatial(batch_size, root, shuffle_train=True):
    """Return (train_loader, test_loader) for MNIST as normalized spatial tensors."""
    # NOTE(review): the Compose list was truncated in the original; reconstructed
    # with ToTensor() before Normalize, using the standard MNIST mean/std.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    return make_data(batch_size=batch_size, root=root, name='MNIST-spatial', dataset=datasets.MNIST, transform=transform, shuffle_train=shuffle_train)
def cifar10_spatial(batch_size, root, shuffle_train=True):
    """Return (train_loader, test_loader) for CIFAR10 as normalized spatial tensors."""
    # NOTE(review): the Compose list was truncated in the original; reconstructed
    # with ToTensor() before Normalize, using the standard CIFAR10 mean/std.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
    ])
    return make_data(batch_size=batch_size, root=root, name='CIFAR10-spatial', dataset=datasets.CIFAR10, transform=transform, shuffle_train=shuffle_train)
def cifar100_spatial(batch_size, root, shuffle_train=True):
    """Return (train_loader, test_loader) for CIFAR100 as normalized spatial tensors."""
    # NOTE(review): the Compose list was truncated in the original; reconstructed
    # with ToTensor() before Normalize, keeping the mean/std the original used
    # (these are the ImageNet statistics — possibly intentional; confirm).
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    return make_data(batch_size=batch_size, root=root, name='CIFAR100-spatial', dataset=datasets.CIFAR100, transform=transform, shuffle_train=shuffle_train)
import argparse
from .datasets import *
from .jpeg_readwrite import *
from jpeg_codec import encode
dataset_map = {
'MNIST': mnist_spatial,
'CIFAR10': cifar10_spatial,
'CIFAR100': cifar100_spatial
parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, help='Number of images to convert at a time', default=60000)
parser.add_argument('dataset', choices=dataset_map.keys(), help='Dataset to convert')
parser.add_argument('directory', help='directory to load and store datasets')
args = parser.parse_args()
train_data, test_data = dataset_map[args.dataset](args.batch_size,, shuffle_train=False)
device = torch.device('cuda')
print('Encoding training data...')
jpeg_data = []
jpeg_label = []
for data, label in train_data:
jpeg_data.append(encode(data, device=device))
print('Writing training data...')
jpeg_data =
jpeg_label =
write_jpeg(jpeg_data, jpeg_label,, 'train.npz')
print('Encoding testing data...')
jpeg_data = []
jpeg_label = []
for data, label in test_data:
jpeg_data.append(encode(data, device=device))
print('Writing testing data...')
jpeg_data =
jpeg_label =
write_jpeg(jpeg_data, jpeg_label,, 'test.npz')
import numpy as np
import torch
import os
def write_jpeg(data, labels, directory, fname):
    """Serialize encoded data and label tensors to an .npz archive.

    Creates `directory` if needed, then writes `data` and `labels`
    (moved to CPU and converted to numpy arrays) under the keys
    'data' and 'labels' in `directory/fname`.
    """
    os.makedirs(directory, exist_ok=True)
    out_path = f'{directory}/{fname}'
    data_arr = data.cpu().numpy()
    label_arr = labels.cpu().numpy()
    with open(out_path, 'wb') as out_file:
        np.savez(out_file, data=data_arr, labels=label_arr)
def load_jpeg(directory, fname):
    """Load an .npz archive written by write_jpeg.

    Args:
        directory: directory containing the archive.
        fname: archive file name (e.g. 'train.npz').

    Returns:
        (data, labels) as float torch.Tensors (torch.Tensor constructor
        semantics, matching the original return types).

    Fixes:
    - open the file in binary mode ('rb'); np.load requires a binary
      stream and fails on a text-mode handle.
    - read the 'labels' key, matching what write_jpeg saves (the
      original read 'label', which raises KeyError).
    - read the arrays while the file is still open; NpzFile members are
      loaded lazily from the underlying handle.
    """
    path = '{}/{}'.format(directory, fname)
    with open(path, 'rb') as f:
        dl = np.load(f)
        data = dl['data']
        labels = dl['labels']
    return torch.Tensor(data), torch.Tensor(labels)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment