import math
from typing import Optional, Union
from torch.utils.data import DataLoader
from ..labeled_dataset import LabeledDataset
from ..unlabeled_dataset import UnlabeledDataset
class BaseDataLoader(DataLoader):
    """Wrapper around PyTorch's DataLoader that performs multi-threaded data
    loading and optionally caps the number of samples drawn from the dataset."""
    def __init__(
        self,
        *,
        dataset: Union[UnlabeledDataset, LabeledDataset],
        max_dataset_size: Optional[int] = None,
        **kwargs,
    ):
        """Initialize this class.

        Args:
            dataset: the dataset to load
            max_dataset_size: maximum number of images to load; ``None`` means no limit
            **kwargs: passed through to PyTorch's ``DataLoader``
        """
        # Unless the caller supplies a custom collate_fn, use prepare_batch so
        # every batch is transposed into per-field lists.
        if "collate_fn" not in kwargs:
            kwargs["collate_fn"] = self.prepare_batch
        super().__init__(dataset, **kwargs)
        print("dataset [%s] was created" % type(self.dataset).__name__)
        self.max_dataset_size = max_dataset_size
    def prepare_batch(self, batch):
        """Transpose a batch of per-sample tuples into a tuple of per-field lists."""
        # A tuple (rather than a single-use generator) is returned so the
        # batch can be indexed and iterated more than once by consumers.
        return tuple(list(item) for item in zip(*batch))
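
    # Example of the transposition performed by prepare_batch, assuming the
    # dataset yields (image, label) pairs (the names are illustrative only):
    #   batch = [(img0, lbl0), (img1, lbl1)]
    #   prepare_batch(batch) -> ([img0, img1], [lbl0, lbl1])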
    def __len__(self):
        """Return the number of batches this loader will yield."""
        return math.ceil(
            min(
                len(self.dataset),
                self.max_dataset_size
                if self.max_dataset_size is not None
                else float("inf"),
            )
            / self.batch_size
        )
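
    # Worked example (hypothetical numbers): with len(dataset) == 103,
    # max_dataset_size == 50 and batch_size == 8, __len__ returns
    # ceil(min(103, 50) / 8) == ceil(6.25) == 7 batches.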
    def __iter__(self):
        """Yield batches, truncating the final batch so that no more than
        ``max_dataset_size`` samples are produced in total."""
        for i, data in enumerate(super().__iter__()):
            if (
                self.max_dataset_size is not None
                and (i + 1) * self.batch_size >= self.max_dataset_size
            ):
                # Shorten the last batch so the total number of samples
                # yielded does not exceed self.max_dataset_size, then stop.
                yield tuple(
                    items[: self.max_dataset_size - i * self.batch_size]
                    for items in data
                )
                break
            else:
                yield data
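
# A minimal usage sketch. The constructor arguments to LabeledDataset are
# elided because its real signature is defined elsewhere; batch_size and
# num_workers are standard torch DataLoader keyword arguments forwarded
# through **kwargs.
#
#     dataset = LabeledDataset(...)      # any dataset yielding sample tuples
#     loader = BaseDataLoader(
#         dataset=dataset,
#         max_dataset_size=1000,         # stop after at most 1000 samples
#         batch_size=8,                  # forwarded to torch's DataLoader
#         num_workers=2,                 # forwarded to torch's DataLoader
#     )
#     for images, labels in loader:      # each field arrives as a list
#         ...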