
classy.data.dataset.base

Functions

batchify

def batchify(
    tensors: List[torch.Tensor],
    padding_value: int,
) -> torch.Tensor
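
A minimal usage sketch (an assumption about the intended behavior, since the function body is not shown here): batchify presumably pads a list of 1-D tensors to the length of the longest one with padding_value and stacks them into a single 2-D batch tensor.

import torch
from classy.data.dataset.base import batchify

# Hypothetical example: two 1-D tensors of different lengths.
tensors = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
batch = batchify(tensors, padding_value=0)
# Under the assumption above, batch has shape (2, 3):
# tensor([[1, 2, 3],
#         [4, 5, 0]])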

batchify_matrices

def batchify_matrices(
    tensors: List[torch.Tensor],
    padding_value: int,
) -> torch.Tensor
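
Similarly, a hedged sketch for batchify_matrices, assuming it pads a list of 2-D tensors along both dimensions to the largest size in the list and stacks them into a 3-D batch:

import torch
from classy.data.dataset.base import batchify_matrices

# Hypothetical example: two matrices with different shapes.
matrices = [torch.ones(2, 3, dtype=torch.long), torch.ones(3, 2, dtype=torch.long)]
batch = batchify_matrices(matrices, padding_value=0)
# Under the assumption above, batch has shape (2, 3, 3), with positions
# outside each original matrix filled with the padding value.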

Classes

BaseDataset

class BaseDataset(torch.utils.data.dataset.IterableDataset)

An iterable Dataset.

All datasets that represent an iterable of data samples should subclass it. This form of dataset is particularly useful when data come from a stream.

All subclasses should overwrite __iter__(), which would return an iterator of samples in this dataset.

When a subclass is used with torch.utils.data.DataLoader, each item in the dataset will be yielded from the DataLoader iterator. When num_workers > 0, each worker process will have a different copy of the dataset object, so it is often desired to configure each copy independently to avoid having duplicate data returned from the workers. torch.utils.data.get_worker_info(), when called in a worker process, returns information about the worker. It can be used in either the dataset's __iter__() method or the DataLoader's worker_init_fn option to modify each copy's behavior.

Example 1: splitting workload across all workers in __iter__:

>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super(MyIterableDataset).__init__()
...         assert end > start, "this example code only works with end > start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         worker_info = torch.utils.data.get_worker_info()
...         if worker_info is None:  # single-process data loading, return the full iterator
...             iter_start = self.start
...             iter_end = self.end
...         else:  # in a worker process
...             # split workload
...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
...             worker_id = worker_info.id
...             iter_start = self.start + worker_id * per_worker
...             iter_end = min(iter_start + per_worker, self.end)
...         return iter(range(iter_start, iter_end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)

>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]

>>> # Multi-process loading with two worker processes
>>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 5, 4, 6]

>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=20)))
[3, 4, 5, 6]

Example 2: splitting workload across all workers using worker_init_fn:

>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super(MyIterableDataset).__init__()
...         assert end > start, "this example code only works with end > start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         return iter(range(self.start, self.end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)

>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]
>>>
>>> # Directly doing multi-process loading yields duplicate data
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 3, 4, 4, 5, 5, 6, 6]

>>> # Define a worker_init_fn that configures each dataset copy differently
>>> def worker_init_fn(worker_id):
...     worker_info = torch.utils.data.get_worker_info()
...     dataset = worker_info.dataset  # the dataset copy in this worker process
...     overall_start = dataset.start
...     overall_end = dataset.end
...     # configure the dataset to only process the split workload
...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
...     worker_id = worker_info.id
...     dataset.start = overall_start + worker_id * per_worker
...     dataset.end = min(dataset.start + per_worker, overall_end)
...

>>> # Multi-process loading with the custom worker_init_fn
>>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
[3, 5, 4, 6]

>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=20, worker_init_fn=worker_init_fn)))
[3, 4, 5, 6]


__init__

def __init__(
    samples_iterator: Callable[[], Iterator[ClassySample]],
    vocabulary: Vocabulary,
    for_inference: bool,
    batch_size: Optional[int] = None,
    tokens_per_batch: Optional[int] = None,
    max_batch_size: Optional[int] = None,
    batching_fields: Optional[List[str]] = None,
    section_size: Optional[int] = None,
    prebatch: bool = False,
    materialize: bool = False,
    drop_last: bool = False,
    min_length: int = -1,
    max_length: int = -1,
)
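
A hypothetical construction sketch showing how the batching parameters might combine. BaseDataset is abstract, so the concrete subclass MySequenceDataset, the sample iterator, and the field name below are illustrative assumptions rather than part of this module:

# Purely illustrative; names below are placeholders.
dataset = MySequenceDataset(
    samples_iterator=lambda: iter(train_samples),  # Callable[[], Iterator[ClassySample]]
    vocabulary=vocabulary,                         # a previously fitted Vocabulary
    for_inference=False,
    tokens_per_batch=1000,       # token-based batching instead of a fixed batch_size
    max_batch_size=64,           # additional cap on samples per batch
    batching_fields=["tokens"],  # assumed: fields whose lengths drive the token count
    prebatch=True,               # assumed: group similarly sized elements before batching
    max_length=512,              # assumed: filter or truncate overly long samples
    drop_last=False,
)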

dataset_iterator_func

def dataset_iterator_func(
    self,
)

materialize_batches

def materialize_batches(
    self,
    dataset_elements: List[Dict[str, Any]],
) -> Generator[Dict[str, Any], None, None]

materialize_dataset

def materialize_dataset(
    self,
) -> None

prebatch_elements

def prebatch_elements(
    self,
    dataset_elements: List[~T],
)

adapt_dataset_from

@classmethod
def adapt_dataset_from(
    training_dataset: omegaconf.dictconfig.DictConfig,
    setting: str,
)

fit_vocabulary

@classmethod
def fit_vocabulary(
    samples: Iterator[ClassySample],
) -> Vocabulary
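
A short ordering sketch, assuming fit_vocabulary is run on the training samples before the dataset is built; MyDataset and train_samples are placeholder names:

# Hypothetical: fit a Vocabulary once from the training samples.
if MyDataset.requires_vocab():
    vocabulary = MyDataset.fit_vocabulary(iter(train_samples))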

from_file

@classmethod
def from_file(
    path: Union[str, Dict[str, DataDriver]],
    data_driver: Optional[DataDriver] = None,
    vocabulary: Vocabulary = None,
    **kwargs,
) -> BaseDataset
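
A hedged usage sketch for from_file; the path, the data driver, and the extra keyword argument (presumably forwarded to the constructor via **kwargs) are illustrative assumptions:

# Hypothetical: build a dataset directly from a file on disk.
dataset = MyDataset.from_file(
    "data/train.tsv",            # placeholder path
    data_driver=my_data_driver,  # a DataDriver suited to the file format (assumed)
    vocabulary=vocabulary,       # reuse a previously fitted Vocabulary
    for_inference=False,         # assumed to be forwarded through **kwargs
)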

from_samples

@classmethod
def from_samples(
    samples: Iterator[ClassySample],
    vocabulary: Vocabulary,
    **kwargs,
)

requires_vocab

@classmethod
def requires_vocab() -> bool