jpeg_dataset.py 342 Bytes
Newer Older
Max Ehrlich's avatar
Max Ehrlich committed
1
2
3
4
5
6
7
8
9
10
11
12
13
import torch
from jpeg_readwrite import load_jpeg


class JpegDataset(torch.utils.data.Dataset):
    """Map-style dataset over samples and labels loaded with ``load_jpeg``.

    ``load_jpeg(directory, fname)`` is called once at construction and must
    return a ``(data, labels)`` pair. Indexing the dataset returns the
    matching ``(data[idx], labels[idx])`` pair.
    """

    def __init__(self, directory, fname):
        # NOTE(review): assumes ``data`` exposes ``.shape`` with samples along
        # the first axis and ``labels`` is indexed in the same order — confirm
        # against jpeg_readwrite.load_jpeg.
        self.data, self.labels = load_jpeg(directory, fname)

    def __len__(self):
        """Return the number of samples (size of the first axis of ``data``)."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair at position ``idx``."""
        # Samples live in ``self.data`` (assigned in ``__init__``); there is no
        # ``self.dataset`` attribute.
        return self.data[idx], self.labels[idx]