Verified Commit e29753b4 authored by Max Ehrlich's avatar Max Ehrlich
Browse files

Fix bugs in jpeg data loading

parent 3b8e93f7
from .datasets import *
\ No newline at end of file
import torch import torch
from torchvision import datasets, transforms from torchvision import datasets, transforms
from jpeg_dataset import JpegDataset from .jpeg_dataset import JpegDataset
def spatial_data(batch_size, root, name, dataset, transform, shuffle_train): def spatial_data(batch_size, root, name, dataset, transform, shuffle_train):
...@@ -64,3 +64,31 @@ def cifar10_jpeg(batch_size, root, shuffle_train=True): ...@@ -64,3 +64,31 @@ def cifar10_jpeg(batch_size, root, shuffle_train=True):
def cifar100_jpeg(batch_size, root, shuffle_train=True): def cifar100_jpeg(batch_size, root, shuffle_train=True):
directory = '{}/{}'.format(root, 'CIFAR100') directory = '{}/{}'.format(root, 'CIFAR100')
return jpeg_data(batch_size, directory, shuffle_train) return jpeg_data(batch_size, directory, shuffle_train)
spatial_dataset_map = {
'MNIST': mnist_spatial,
'CIFAR10': cifar10_spatial,
'CIFAR100': cifar100_spatial
}
jpeg_dataset_map = {
'MNIST': mnist_jpeg,
'CIFAR10': cifar10_jpeg,
'CIFAR100': cifar100_jpeg
}
dataset_info = {
'MNIST': {
'channels': 1,
'classes': 10
},
'CIFAR10': {
'channels': 3,
'classes': 10
},
'CIFAR100': {
'channels': 3,
'classes': 100,
}
}
import torch import torch
from jpeg_readwrite import load_jpeg from .jpeg_readwrite import load_jpeg
class JpegDataset(torch.utils.data.Dataset): class JpegDataset(torch.utils.data.Dataset):
...@@ -10,4 +10,4 @@ class JpegDataset(torch.utils.data.Dataset): ...@@ -10,4 +10,4 @@ class JpegDataset(torch.utils.data.Dataset):
return self.data.shape[0] return self.data.shape[0]
def __getitem__(self, idx): def __getitem__(self, idx):
return self.dataset[idx], self.labels[idx] return self.data[idx], self.labels[idx]
...@@ -13,10 +13,9 @@ def write_jpeg(data, labels, directory, fname): ...@@ -13,10 +13,9 @@ def write_jpeg(data, labels, directory, fname):
def load_jpeg(directory, fname): def load_jpeg(directory, fname):
path = '{}/{}'.format(directory, fname) path = '{}/{}'.format(directory, fname)
with open(path) as f: with open(path, 'rb') as f:
dl = np.load(f) dl = np.load(f)
data = dl['data']
labels = dl['labels']
data = dl['data'] return torch.Tensor(data), torch.Tensor(labels)
label = dl['label']
return torch.Tensor(data), torch.Tensor(label)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment