# datasets.py -- loaders for spatial (pixel-domain) and JPEG-domain datasets
import torch
from torchvision import datasets, transforms

from jpeg_dataset import JpegDataset


def spatial_data(batch_size, root, name, dataset, transform, shuffle_train):
    # Download (if needed) a torchvision dataset and wrap it in train/test loaders.
    train_data = dataset('{}/{}-data'.format(root, name), train=True, download=True, transform=transform)
    test_data = dataset('{}/{}-data'.format(root, name), train=False, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=shuffle_train)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader


def mnist_spatial(batch_size, root, shuffle_train=True):
    # Pad 28x28 MNIST images to 32x32 so they match the CIFAR input size.
    transform = transforms.Compose([
        transforms.Pad(2),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    return spatial_data(batch_size=batch_size, root=root, name='MNIST-spatial', dataset=datasets.MNIST, transform=transform, shuffle_train=shuffle_train)


def cifar10_spatial(batch_size, root, shuffle_train=True):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
    ])
    return spatial_data(batch_size=batch_size, root=root, name='CIFAR10-spatial', dataset=datasets.CIFAR10, transform=transform, shuffle_train=shuffle_train)


def cifar100_spatial(batch_size, root, shuffle_train=True):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    return spatial_data(batch_size=batch_size, root=root, name='CIFAR100-spatial', dataset=datasets.CIFAR100, transform=transform, shuffle_train=shuffle_train)


def jpeg_data(batch_size, directory, shuffle_train):
    # Load the pre-encoded JPEG-domain tensors written by the conversion script.
    train_data = JpegDataset(directory=directory, fname='train.npz')
    test_data = JpegDataset(directory=directory, fname='test.npz')
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=shuffle_train)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader


def mnist_jpeg(batch_size, root, shuffle_train=True):
    directory = '{}/{}'.format(root, 'MNIST')
    return jpeg_data(batch_size, directory, shuffle_train)


def cifar10_jpeg(batch_size, root, shuffle_train=True):
    directory = '{}/{}'.format(root, 'CIFAR10')
    return jpeg_data(batch_size, directory, shuffle_train)


def cifar100_jpeg(batch_size, root, shuffle_train=True):
    directory = '{}/{}'.format(root, 'CIFAR100')
    return jpeg_data(batch_size, directory, shuffle_train)
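As a quick reference, a usage sketch for the spatial loaders; the batch size and the './data' root are placeholder values, not paths taken from the repository, and the functions are assumed importable from this module.

# Usage sketch (placeholder values): torchvision downloads CIFAR-10 into
# './data/CIFAR10-spatial-data' on first use.
train_loader, test_loader = cifar10_spatial(batch_size=128, root='./data')

images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # (128, 3, 32, 32) normalized pixels, (128,) class ids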
# jpeg_dataset.py -- wraps the .npz archives produced by the conversion script as a Dataset
import torch

from jpeg_readwrite import load_jpeg


class JpegDataset(torch.utils.data.Dataset):
    def __init__(self, directory, fname):
        self.data, self.labels = load_jpeg(directory, fname)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]
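For illustration, a minimal sketch of using JpegDataset directly; './data/MNIST' and 'train.npz' are placeholder paths and assume the conversion script below has already been run.

# Placeholder paths: the archive must already exist on disk.
dataset = JpegDataset(directory='./data/MNIST', fname='train.npz')

print(len(dataset))          # number of encoded images
sample, label = dataset[0]   # one encoded tensor and its class label
print(sample.shape, label)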
# Conversion script: encode a spatial dataset into the JPEG transform domain
# and store the result as train.npz / test.npz.
import argparse

import torch

from .datasets import *
from .jpeg_readwrite import *
from jpeg_codec import encode

dataset_map = {
    'MNIST': mnist_spatial,
    'CIFAR10': cifar10_spatial,
    'CIFAR100': cifar100_spatial
}

parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, help='Number of images to convert at a time', default=60000)
parser.add_argument('dataset', choices=dataset_map.keys(), help='Dataset to convert')
parser.add_argument('directory', help='Directory to load and store datasets')
args = parser.parse_args()
print(args)

# Load the spatial-domain data without shuffling so the encoded order is stable.
train_data, test_data = dataset_map[args.dataset](args.batch_size, root=args.directory, shuffle_train=False)

device = torch.device('cuda')

print('Encoding training data...')
jpeg_data = []
jpeg_label = []
for data, label in train_data:
    jpeg_data.append(encode(data, device=device))
    jpeg_label.append(label)

print('Writing training data...')
jpeg_data = torch.cat(jpeg_data)
jpeg_label = torch.cat(jpeg_label)
write_jpeg(jpeg_data, jpeg_label, args.directory, 'train.npz')

print('Encoding testing data...')
jpeg_data = []
jpeg_label = []
for data, label in test_data:
    jpeg_data.append(encode(data, device=device))
    jpeg_label.append(label)

print('Writing testing data...')
jpeg_data = torch.cat(jpeg_data)
jpeg_label = torch.cat(jpeg_label)
write_jpeg(jpeg_data, jpeg_label, args.directory, 'test.npz')
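Once the archives exist, they are consumed through the *_jpeg loaders in datasets.py. Note that write_jpeg stores train.npz and test.npz directly under the directory passed to this script, whereas mnist_jpeg and the other helpers look under '{root}/MNIST' (and likewise for the CIFAR variants), so the script is typically pointed at that per-dataset subdirectory. A minimal sketch with placeholder paths:

# Placeholder paths: assumes this script was run with directory='./data/MNIST'.
train_loader, test_loader = mnist_jpeg(batch_size=128, root='./data')

blocks, labels = next(iter(train_loader))
print(blocks.shape, labels.shape)  # encoded JPEG-domain tensors and their labels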
# jpeg_readwrite.py -- persist encoded tensors and labels as .npz archives
import os

import numpy as np
import torch


def write_jpeg(data, labels, directory, fname):
    os.makedirs(directory, exist_ok=True)
    path = '{}/{}'.format(directory, fname)
    with open(path, 'wb') as f:
        np.savez(f, data=data.cpu().numpy(), labels=labels.cpu().numpy())


def load_jpeg(directory, fname):
    path = '{}/{}'.format(directory, fname)
    # Open in binary mode and read back the same keys that write_jpeg stores.
    with open(path, 'rb') as f:
        dl = np.load(f)
        data = dl['data']
        labels = dl['labels']
    # from_numpy preserves the saved dtypes (e.g. integer class labels).
    return torch.from_numpy(data), torch.from_numpy(labels)
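A small round-trip sketch of the archive layout (dummy shapes and the './tmp-jpeg' path are chosen only for illustration): each .npz stores a 'data' array and a 'labels' array, which load_jpeg restores as tensors with their original dtypes.

# Dummy stand-ins for encoded data.
fake_data = torch.randn(10, 64)
fake_labels = torch.randint(0, 10, (10,))

write_jpeg(fake_data, fake_labels, './tmp-jpeg', 'train.npz')
data, labels = load_jpeg('./tmp-jpeg', 'train.npz')

print(data.shape, labels.dtype)  # torch.Size([10, 64]) and torch.int64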
@@ -3,15 +3,15 @@ from .blocks import SpatialResBlock
 class SpatialResNet(nn.Module):
-    def __init__(self):
+    def __init__(self, n_channels, n_classes):
         super(SpatialResNet, self).__init__()
-        self.block1 = SpatialResBlock(in_channels=1, out_channels=16, downsample=False)
+        self.block1 = SpatialResBlock(in_channels=n_channels, out_channels=16, downsample=False)
         self.block2 = SpatialResBlock(in_channels=16, out_channels=32)
         self.block3 = SpatialResBlock(in_channels=32, out_channels=64)
         self.averagepooling = nn.AvgPool2d(8, stride=1)
-        self.fc = nn.Linear(64, 10)
+        self.fc = nn.Linear(64, n_classes)

     def forward(self, x):
         out = self.block1(x)
......
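With the parameterized constructor, the same model serves MNIST (1 channel) and the CIFAR datasets (3 channels). A minimal instantiation sketch; the module path and the 32x32 input size (MNIST is padded to 32 by the loaders above) are assumptions, and the truncated forward is assumed to apply the blocks, pooling, and fc in order.

import torch
from spatial_resnet import SpatialResNet  # hypothetical module path

# MNIST configuration; CIFAR-100 would use n_channels=3, n_classes=100.
model = SpatialResNet(n_channels=1, n_classes=10)

dummy = torch.randn(8, 1, 32, 32)
logits = model(dummy)
print(logits.shape)  # expected (8, 10) under the assumptions above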