Source code for sconce.data_generators.single_class_image_data_generator

from sconce.data_generators import DataGenerator, ImageMixin
from sconce.datasets import Subset
from torchvision import datasets, transforms

import numpy as np
import os
import pandas as pd
import tempfile


class SingleClassImageDataGenerator(DataGenerator, ImageMixin):
    """
    An ImageDataGenerator class for use when each image belongs to exactly one class.

    New in 0.10.0
    """

    def _get_class_df(self, targets=None):
        """
        Build a table marking, for every target, which class it belongs to.

        Arguments:
            targets (iterable): class indices, one per sample.  If ``None``,
                ``self.dataset.targets`` is used.

        Returns:
            pandas.DataFrame: one row per target and one boolean column per
            entry of ``self.dataset.classes``; a cell is ``True`` when the
            target equals that class's index in ``self.dataset.class_to_idx``.

        Raises:
            RuntimeError: if ``targets`` is ``None`` and the dataset has no
                ``targets`` attribute.
        """
        dataset = self.dataset
        if targets is None:
            if not hasattr(dataset, 'targets'):
                raise RuntimeError("No targets were supplied, and the dataset doesn't "
                                   "have a 'targets' attribute")
            targets = dataset.targets

        # Resolve every class index once, rather than once per (target, class) pair.
        index_by_class = {_class: dataset.class_to_idx[_class]
                          for _class in dataset.classes}

        # bool(...) mirrors the original ``if target == idx`` truthiness test, so
        # targets whose ``==`` returns a non-bool scalar still yield True/False.
        rows = [{_class: bool(target == idx) for _class, idx in index_by_class.items()}
                for target in targets]

        # Passing ``columns`` keeps the per-class columns even when there are no rows.
        return pd.DataFrame(rows, columns=list(index_by_class))

    @classmethod
    def from_torchvision(cls, batch_size=500, data_location=None,
                         dataset_class=datasets.MNIST, fraction=1.0,
                         num_workers=0, pin_memory=True, shuffle=True,
                         train=True, transform=transforms.ToTensor()):
        """
        Create a DataGenerator from a torchvision dataset class.

        Arguments:
            batch_size (int): how large the yielded `inputs` and `targets`
                should be.  See :py:class:`DataLoader` for details.
            data_location (path): where downloaded dataset should be stored.
                If ``None`` a system dependent temporary location will be used.
            dataset_class (class): a torchvision dataset class that supports
                constructor arguments {'root', 'train', 'download',
                'transform'}.  For example, MNIST, FashionMNIST, CIFAR10, or
                CIFAR100.
            fraction (float): (0.0 - 1.0] how much of the original dataset's
                data to use.
            num_workers (int): how many subprocesses to use for data loading.
                See :py:class:`DataLoader` for details.
            pin_memory (bool): if ``True``, the data loader will copy tensors
                into CUDA pinned memory before returning them.  See
                :py:class:`DataLoader` for details.
            shuffle (bool): set to ``True`` to have the data reshuffled at
                every epoch.  See :py:class:`DataLoader` for details.
            train (bool): if ``True``, creates dataset from training set,
                otherwise creates from test set.
            transform (callable): a function/transform that takes in a PIL
                image and returns a transformed version.

        Returns:
            Whatever ``cls.from_dataset`` returns for the selected subset
            (``from_dataset`` is not defined in this class; presumably it is
            inherited from ``DataGenerator``).

        Raises:
            ValueError: if ``fraction`` is not in the interval (0.0, 1.0].
        """
        # Explicit check instead of ``assert`` so it still runs under ``python -O``.
        # The negated chained comparison also rejects NaN.
        if not 0.0 < fraction <= 1.0:
            raise ValueError(
                "fraction must be in the interval (0.0, 1.0], got %r" % (fraction,))

        if data_location is None:
            data_location = os.path.join(tempfile.gettempdir(),
                                         dataset_class.__name__)

        dataset = dataset_class(data_location, train=train, download=True,
                                transform=transform)

        # Pick evenly spaced sample positions across the whole dataset so the
        # subset spans the original ordering instead of taking only a prefix.
        num_samples = int(len(dataset) * fraction)
        indices = [int(x) for x in np.linspace(start=0,
                                               stop=len(dataset) - 1,
                                               num=num_samples)]
        subset = Subset(dataset, indices=indices)

        return cls.from_dataset(subset, batch_size=batch_size,
                                num_workers=num_workers,
                                pin_memory=pin_memory, shuffle=shuffle)

    @classmethod
    def from_image_folder(cls, root, loader_kwargs=None, **dataset_kwargs):
        """
        Create a DataGenerator from a folder of images.
        See :py:class:`torchvision.datasets.ImageFolder`.

        Arguments:
            root (path): the root directory path.
            loader_kwargs (dict): keyword args provided to the DataLoader
                constructor.
            **dataset_kwargs: keyword args provided to the
                :py:class:`torchvision.datasets.ImageFolder` constructor.

        Returns:
            Whatever ``cls.from_dataset`` returns for the image-folder dataset.
        """
        if loader_kwargs is None:
            loader_kwargs = {}
        dataset = datasets.ImageFolder(root=root, **dataset_kwargs)
        return cls.from_dataset(dataset, **loader_kwargs)