# datasets.py
import torch
from torchvision import datasets, transforms
from jpeg_dataset import JpegDataset


def spatial_data(batch_size, root, name, dataset, transform, shuffle_train):
    """Build train and test ``DataLoader`` objects for a torchvision-style dataset.

    Args:
        batch_size: Batch size used for both loaders.
        root: Base directory; the data lives under ``{root}/{name}-data``.
        name: Label used to form the storage directory name.
        dataset: Callable invoked as
            ``dataset(path, train=..., download=True, transform=...)``
            that returns a dataset object for the requested split.
        transform: Transform passed unchanged to both splits.
        shuffle_train: Whether the training loader shuffles its samples.

    Returns:
        A ``(train_loader, test_loader)`` tuple. The test loader never
        shuffles.
    """
    # Both splits share one storage directory, so build the path once.
    data_dir = '{}/{}-data'.format(root, name)

    train_data = dataset(data_dir, train=True, download=True, transform=transform)
    test_data = dataset(data_dir, train=False, download=True, transform=transform)

    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=shuffle_train)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader


def mnist_spatial(batch_size, root, shuffle_train=True):
    """Return ``(train_loader, test_loader)`` for MNIST via ``spatial_data``.

    Args:
        batch_size: Batch size used for both loaders.
        root: Base directory; ``spatial_data`` stores the data under
            ``{root}/MNIST-spatial-data``.
        shuffle_train: Whether the training loader shuffles its samples.

    Returns:
        The ``(train_loader, test_loader)`` tuple produced by ``spatial_data``.
    """
    transform = transforms.Compose([
        # Pad every border by 2 pixels before converting to a tensor.
        # NOTE(review): presumably this grows 28x28 digits to 32x32 - confirm.
        transforms.Pad(2),
        transforms.ToTensor(),
        # Single-channel mean / std used for normalisation.
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    return spatial_data(batch_size=batch_size, root=root, name='MNIST-spatial', dataset=datasets.MNIST, transform=transform, shuffle_train=shuffle_train)


def cifar10_spatial(batch_size, root, shuffle_train=True):
    """Return ``(train_loader, test_loader)`` for CIFAR-10 via ``spatial_data``.

    Args:
        batch_size: Batch size used for both loaders.
        root: Base directory; ``spatial_data`` stores the data under
            ``{root}/CIFAR10-spatial-data``.
        shuffle_train: Whether the training loader shuffles its samples.

    Returns:
        The ``(train_loader, test_loader)`` tuple produced by ``spatial_data``.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        # Three-channel mean / std used for normalisation.
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
    ])

    return spatial_data(batch_size=batch_size, root=root, name='CIFAR10-spatial', dataset=datasets.CIFAR10, transform=transform, shuffle_train=shuffle_train)


def cifar100_spatial(batch_size, root, shuffle_train=True):
    """Return ``(train_loader, test_loader)`` for CIFAR-100 via ``spatial_data``.

    Args:
        batch_size: Batch size used for both loaders.
        root: Base directory; ``spatial_data`` stores the data under
            ``{root}/CIFAR100-spatial-data``.
        shuffle_train: Whether the training loader shuffles its samples.

    Returns:
        The ``(train_loader, test_loader)`` tuple produced by ``spatial_data``.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        # NOTE(review): these constants look like the widely quoted ImageNet
        # channel statistics rather than CIFAR-100-specific ones - confirm this
        # is intentional before changing, since trained models depend on it.
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    return spatial_data(batch_size=batch_size, root=root, name='CIFAR100-spatial', dataset=datasets.CIFAR100, transform=transform, shuffle_train=shuffle_train)


def jpeg_data(batch_size, directory, shuffle_train):
    """Build train and test ``DataLoader`` objects over ``JpegDataset`` splits.

    Args:
        batch_size: Batch size used for both loaders.
        directory: Directory handed to ``JpegDataset`` for both splits.
        shuffle_train: Whether the training loader shuffles its samples.

    Returns:
        A ``(train_loader, test_loader)`` tuple. The test loader never
        shuffles.
    """
    # The two splits differ only in the file name given to JpegDataset.
    train_set = JpegDataset(directory=directory, fname='train.npz')
    test_set = JpegDataset(directory=directory, fname='test.npz')

    make_loader = torch.utils.data.DataLoader
    loaders = (
        make_loader(train_set, batch_size=batch_size, shuffle=shuffle_train),
        make_loader(test_set, batch_size=batch_size, shuffle=False),
    )
    return loaders


def mnist_jpeg(batch_size, root, shuffle_train=True):
    """Return the ``jpeg_data`` loaders for the ``MNIST`` directory under ``root``."""
    return jpeg_data(
        batch_size=batch_size,
        directory='{}/{}'.format(root, 'MNIST'),
        shuffle_train=shuffle_train,
    )


def cifar10_jpeg(batch_size, root, shuffle_train=True):
    """Return the ``jpeg_data`` loaders for the ``CIFAR10`` directory under ``root``."""
    return jpeg_data(
        batch_size=batch_size,
        directory='{}/{}'.format(root, 'CIFAR10'),
        shuffle_train=shuffle_train,
    )


def cifar100_jpeg(batch_size, root, shuffle_train=True):
    """Return the ``jpeg_data`` loaders for the ``CIFAR100`` directory under ``root``."""
    return jpeg_data(
        batch_size=batch_size,
        directory='{}/{}'.format(root, 'CIFAR100'),
        shuffle_train=shuffle_train,
    )