Shortcuts

Source code for torch.utils.data.dataset

import bisect
import random
import warnings

from torch._utils import _accumulate
from torch import randperm
# No 'default_generator' in torch/__init__.pyi
from torch import default_generator  # type: ignore
from typing import TypeVar, Generic, Iterable, Iterator, Sequence, List, Optional, Tuple
from ... import Tensor, Generator

# Covariant type variable standing for the sample type a dataset yields;
# used as the type parameter of the dataset classes in this module.
T_co = TypeVar('T_co', covariant=True)
# Plain (invariant) type variable; used by the signature of `random_split`.
T = TypeVar('T')


class Dataset(Generic[T_co]):
    r"""Abstract base class for map-style datasets.

    Every dataset that represents a map from keys to data samples should
    derive from this class.

    Subclasses must override :meth:`__getitem__` so that a data sample can be
    fetched for a given key. They may additionally override :meth:`__len__`,
    which many :class:`~torch.utils.data.Sampler` implementations and the
    default options of :class:`~torch.utils.data.DataLoader` expect to return
    the size of the dataset.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs an index
      sampler that yields integral indices. To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be
      provided.
    """

    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        # ``a + b`` yields the concatenation of both datasets.
        return ConcatDataset([self, other])

    def __getitem__(self, index) -> T_co:
        # Abstract: concrete datasets provide the key -> sample lookup.
        raise NotImplementedError

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py
class IterableDataset(Dataset[T_co]):
    r"""Base class for iterable-style datasets.

    Every dataset that represents an iterable of data samples should derive
    from this class. This form of dataset is particularly useful when the data
    come from a stream.

    Subclasses must override :meth:`__iter__` so that it returns an iterator
    over the samples in the dataset.

    When a subclass is used with :class:`~torch.utils.data.DataLoader`, each
    item in the dataset will be yielded from the
    :class:`~torch.utils.data.DataLoader` iterator. When
    :attr:`num_workers > 0`, each worker process will have a different copy of
    the dataset object, so it is often desired to configure each copy
    independently to avoid having duplicate data returned from the workers.
    :func:`~torch.utils.data.get_worker_info`, when called in a worker process,
    returns information about the worker. It can be used in either the
    dataset's :meth:`__iter__` method or the
    :class:`~torch.utils.data.DataLoader` 's :attr:`worker_init_fn` option to
    modify each copy's behavior.

    Example 1: splitting workload across all workers in :meth:`__iter__`::

        >>> class MyIterableDataset(torch.utils.data.IterableDataset):
        ...     def __init__(self, start, end):
        ...         super(MyIterableDataset, self).__init__()
        ...         assert end > start, "this example code only works with end > start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         worker_info = torch.utils.data.get_worker_info()
        ...         if worker_info is None:  # single-process data loading, return the full iterator
        ...             iter_start = self.start
        ...             iter_end = self.end
        ...         else:  # in a worker process
        ...             # split workload
        ...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
        ...             worker_id = worker_info.id
        ...             iter_start = self.start + worker_id * per_worker
        ...             iter_end = min(iter_start + per_worker, self.end)
        ...         return iter(range(iter_start, iter_end))
        ...
        >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
        >>> ds = MyIterableDataset(start=3, end=7)

        >>> # Single-process loading
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
        [3, 4, 5, 6]

        >>> # Multi-process loading with two worker processes
        >>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
        [3, 5, 4, 6]

        >>> # With even more workers
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=20)))
        [3, 4, 5, 6]

    Example 2: splitting workload across all workers using :attr:`worker_init_fn`::

        >>> class MyIterableDataset(torch.utils.data.IterableDataset):
        ...     def __init__(self, start, end):
        ...         super(MyIterableDataset, self).__init__()
        ...         assert end > start, "this example code only works with end > start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         return iter(range(self.start, self.end))
        ...
        >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
        >>> ds = MyIterableDataset(start=3, end=7)

        >>> # Single-process loading
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
        [3, 4, 5, 6]
        >>>
        >>> # Directly doing multi-process loading yields duplicate data
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
        [3, 3, 4, 4, 5, 5, 6, 6]

        >>> # Define a `worker_init_fn` that configures each dataset copy differently
        >>> def worker_init_fn(worker_id):
        ...     worker_info = torch.utils.data.get_worker_info()
        ...     dataset = worker_info.dataset  # the dataset copy in this worker process
        ...     overall_start = dataset.start
        ...     overall_end = dataset.end
        ...     # configure the dataset to only process the split workload
        ...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
        ...     worker_id = worker_info.id
        ...     dataset.start = overall_start + worker_id * per_worker
        ...     dataset.end = min(dataset.start + per_worker, overall_end)
        ...

        >>> # Multi-process loading with the custom `worker_init_fn`
        >>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
        [3, 5, 4, 6]

        >>> # With even more workers
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=20, worker_init_fn=worker_init_fn)))
        [3, 4, 5, 6]
    """

    def __add__(self, other: Dataset[T_co]):
        # ``a + b`` chains the two streams one after the other.
        return ChainDataset([self, other])

    def __iter__(self) -> Iterator[T_co]:
        # Abstract: concrete datasets provide the sample iterator.
        raise NotImplementedError

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
class TensorDataset(Dataset[Tuple[Tensor, ...]]):
    r"""Dataset that wraps a group of tensors.

    A sample is obtained by indexing every wrapped tensor along its first
    dimension and returning the results as a tuple.

    Args:
        *tensors (Tensor): tensors that have the same size of the first
            dimension.
    """
    tensors: Tuple[Tensor, ...]

    def __init__(self, *tensors: Tensor) -> None:
        # Every tensor must agree with the first one on its leading size.
        for tensor in tensors:
            assert tensors[0].size(0) == tensor.size(0), "Size mismatch between tensors"
        self.tensors = tensors

    def __len__(self):
        # The shared leading dimension is the number of samples.
        leading = self.tensors[0]
        return leading.size(0)

    def __getitem__(self, index):
        sample = [tensor[index] for tensor in self.tensors]
        return tuple(sample)
class ConcatDataset(Dataset[T_co]):
    r"""Dataset as a concatenation of multiple datasets.

    This class is useful to assemble different existing datasets.

    Args:
        datasets (iterable): Datasets to be concatenated. Must contain at
            least one dataset, and none of them may be an
            :class:`IterableDataset`.
    """
    datasets: List[Dataset[T_co]]
    cumulative_sizes: List[int]

    @staticmethod
    def cumsum(sequence):
        r"""Return the running totals of ``len(e)`` over ``sequence``."""
        r, s = [], 0
        for e in sequence:
            s += len(e)
            r.append(s)
        return r

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super(ConcatDataset, self).__init__()
        # Materialize before checking emptiness: `datasets` is only required
        # to be iterable, and calling `len()` on e.g. a generator would raise
        # TypeError instead of reaching the intended check.
        self.datasets = list(datasets)
        assert len(self.datasets) > 0, 'datasets should not be an empty iterable'
        for d in self.datasets:
            assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
        # cumulative_sizes[i] is the total length of datasets[0..i].
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        r"""Return the sample at global position ``idx``.

        Negative indices count from the end of the concatenation.

        Raises:
            ValueError: if ``idx`` is negative and its absolute value exceeds
                the total length.
        """
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        # First dataset whose cumulative size is strictly greater than idx.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            # Convert the global index to an index local to that dataset.
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    def cummulative_sizes(self):
        # Deprecated misspelled alias kept for backward compatibility.
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes
class ChainDataset(IterableDataset):
    r"""Dataset for chaining multiple :class:`IterableDataset` s.

    This class is useful to assemble different existing dataset streams. The
    chaining operation is done on-the-fly, so concatenating large-scale
    datasets with this class will be efficient.

    Args:
        datasets (iterable of IterableDataset): datasets to be chained together
    """

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super(ChainDataset, self).__init__()
        # Stored as given; members are only validated when iterated or sized.
        self.datasets = datasets

    def __len__(self):
        size = 0
        for member in self.datasets:
            assert isinstance(member, IterableDataset), "ChainDataset only supports IterableDataset"
            # Cannot verify that all self.datasets are Sized
            size += len(member)  # type: ignore
        return size

    def __iter__(self):
        for member in self.datasets:
            assert isinstance(member, IterableDataset), "ChainDataset only supports IterableDataset"
            yield from member
class BufferedShuffleDataset(IterableDataset[T_co]):
    r"""Dataset shuffled from the original dataset.

    This class is useful to shuffle an existing instance of an
    IterableDataset. A buffer of `buffer_size` items is filled from the
    dataset first. Once the buffer is full, every further item replaces a
    randomly chosen buffer slot, whose previous occupant is yielded. When the
    source is exhausted, the remaining buffer is shuffled and drained.

    `buffer_size` is required to be larger than 0. For `buffer_size == 1`, the
    dataset is not shuffled. In order to fully shuffle the whole dataset,
    `buffer_size` is required to be greater than or equal to the size of
    dataset.

    When it is used with :class:`~torch.utils.data.DataLoader`, each item in
    the dataset will be yielded from the :class:`~torch.utils.data.DataLoader`
    iterator. And, the method to set up a random seed is different based on
    :attr:`num_workers`.

    For single-process mode (:attr:`num_workers == 0`), the random seed is
    required to be set before the :class:`~torch.utils.data.DataLoader` in the
    main process.

        >>> ds = BufferedShuffleDataset(dataset)
        >>> random.seed(...)
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))

    For multi-process mode (:attr:`num_workers > 0`), the random seed is set
    by a callable function in each worker.

        >>> ds = BufferedShuffleDataset(dataset)
        >>> def init_fn(worker_id):
        ...     random.seed(...)
        >>> print(list(torch.utils.data.DataLoader(ds, ..., num_workers=n, worker_init_fn=init_fn)))

    Args:
        dataset (IterableDataset): The original IterableDataset.
        buffer_size (int): The buffer size for shuffling.
    """
    dataset: IterableDataset[T_co]
    buffer_size: int

    def __init__(self, dataset: IterableDataset[T_co], buffer_size: int) -> None:
        super(BufferedShuffleDataset, self).__init__()
        assert buffer_size > 0, "buffer_size should be larger than 0"
        self.dataset = dataset
        self.buffer_size = buffer_size

    def __iter__(self) -> Iterator[T_co]:
        pool: List[T_co] = []
        for item in self.dataset:
            if len(pool) != self.buffer_size:
                # Still warming up: keep filling the buffer.
                pool.append(item)
                continue
            # Buffer is full: emit a random occupant and reuse its slot.
            slot = random.randint(0, self.buffer_size - 1)
            yield pool[slot]
            pool[slot] = item
        # Source exhausted: emit whatever is left in random order.
        random.shuffle(pool)
        while pool:
            yield pool.pop()
class Subset(Dataset[T_co]):
    r"""View of a dataset restricted to the given indices.

    Args:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
    """
    dataset: Dataset[T_co]
    indices: Sequence[int]

    def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None:
        self.dataset = dataset
        self.indices = indices

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        # Translate the subset-local position into an index of the parent.
        parent_index = self.indices[idx]
        return self.dataset[parent_index]
def random_split(dataset: Dataset[T], lengths: Sequence[int],
                 generator: Optional[Generator] = default_generator) -> List[Subset[T]]:
    r"""Randomly split a dataset into non-overlapping new datasets of given lengths.

    Optionally fix the generator for reproducible results, e.g.:

    >>> random_split(range(10), [3, 7], generator=torch.Generator().manual_seed(42))

    Args:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths of splits to be produced
        generator (Generator): Generator used for the random permutation.

    Returns:
        A list with one :class:`Subset` per entry of ``lengths``, in order.

    Raises:
        ValueError: if the lengths do not add up to ``len(dataset)``.
    """
    total = sum(lengths)
    # Cannot verify that dataset is Sized
    if total != len(dataset):  # type: ignore
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    indices = randperm(total, generator=generator).tolist()

    # Hand out consecutive slices of the permutation, one per requested length.
    splits = []
    start = 0
    for length in lengths:
        stop = start + length
        splits.append(Subset(dataset, indices[start:stop]))
        start = stop
    return splits

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources