Source code for kitcar_ml.utils.data.data_loader.base_data_loader

import math
from typing import Optional, Union

from torch.utils.data import DataLoader

from ..labeled_dataset import LabeledDataset
from ..unlabeled_dataset import UnlabeledDataset


class BaseDataLoader(DataLoader):
    """Wrap a dataset in PyTorch's DataLoader to perform multi-process data loading."""

    def __init__(
        self,
        *,
        dataset: Union[UnlabeledDataset, LabeledDataset],
        max_dataset_size: Optional[int] = None,
        **kwargs,
    ):
        """Initialize this class.

        Args:
            dataset: The dataset to load.
            max_dataset_size: Maximum number of images to load; None means no limit.
            **kwargs: Passed on to PyTorch's DataLoader.
        """
        # Use prepare_batch as the default collate function so that it is
        # applied to every batch, unless the caller supplies their own.
        if "collate_fn" not in kwargs:
            kwargs["collate_fn"] = self.prepare_batch
        super().__init__(dataset, **kwargs)
        print(f"dataset [{type(self.dataset).__name__}] was created")
        self.max_dataset_size = max_dataset_size
    def prepare_batch(self, batch):
        """Collate a batch by transposing it.

        Turns ``[(img_1, lbl_1), (img_2, lbl_2), ...]`` into
        ``([img_1, img_2, ...], [lbl_1, lbl_2, ...])``.
        """
        # Return a concrete list (not a generator) so that batches stay
        # picklable when loading with multiple worker processes.
        return [list(item) for item in zip(*batch)]
    def __len__(self):
        """Return the number of batches, taking max_dataset_size into account."""
        return math.ceil(
            min(
                len(self.dataset),
                self.max_dataset_size if self.max_dataset_size is not None else float("inf"),
            )
            / self.batch_size
        )

    def __iter__(self):
        """Yield batches of data, loading at most max_dataset_size samples in total."""
        for i, data in enumerate(super().__iter__()):
            if self.max_dataset_size and (i + 1) * self.batch_size >= self.max_dataset_size:
                # The last batch is shortened so that the total number of
                # loaded images does not exceed self.max_dataset_size.
                yield list(
                    items[: self.max_dataset_size - i * self.batch_size] for items in data
                )
                break
            else:
                yield data
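
A minimal usage sketch (not part of the module): ``ToyDataset`` below is a hypothetical stand-in for LabeledDataset, whose real constructor is defined elsewhere in kitcar_ml.utils.data. It illustrates how max_dataset_size caps the total number of loaded samples and how prepare_batch transposes each batch.

from torch.utils.data import Dataset


class ToyDataset(Dataset):
    """Hypothetical dataset yielding (image, label) pairs; plain ints for brevity."""

    def __len__(self):
        return 10

    def __getitem__(self, index):
        return index, index % 2  # (image, label)


loader = BaseDataLoader(dataset=ToyDataset(), max_dataset_size=5, batch_size=2)
print(len(loader))  # ceil(min(10, 5) / 2) == 3 batches
for images, labels in loader:
    # prepare_batch turns [(img, lbl), ...] into ([img, ...], [lbl, ...]);
    # the final batch is cut so that at most 5 samples are yielded in total.
    print(images, labels)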