Source code for kitcar_ml.utils.data.data_loader.utils

import os
from typing import List, Optional, Tuple, Union

from ..labeled_dataset import LabeledDataset
from ..unlabeled_dataset import UnlabeledDataset
from .base_data_loader import BaseDataLoader


def load_labeled_dataset(
    label_file: str,
    max_dataset_size: Optional[int],
    batch_size: int,
    sequential: bool,
    num_workers: int,
) -> BaseDataLoader:
    """Build a dataloader around the labeled dataset described by a label file.

    The dataset is read with ``LabeledDataset.from_yaml`` and its
    ``_base_path`` is set to the directory that contains ``label_file``.

    Args:
        label_file: Path to a file containing all labels.
        max_dataset_size: Maximum amount of images to load; None means infinity.
        batch_size: Batch size.
        sequential: If true, takes images in order, otherwise takes them
            randomly (passed to the loader as ``shuffle=not sequential``).
        num_workers: Threads for loading data.

    Returns:
        A ``BaseDataLoader`` wrapping the loaded dataset.
    """
    labeled = LabeledDataset.from_yaml(label_file)
    label_dir = os.path.dirname(label_file)
    labeled._base_path = label_dir

    # Wrap the dataset in a dataloader; shuffling is the inverse of sequential.
    shuffle = not sequential
    loader = BaseDataLoader(
        dataset=labeled,
        max_dataset_size=max_dataset_size,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
    )
    return loader
def load_unpaired_unlabeled_datasets(
    dir_a: Union[str, List[str]],
    dir_b: Union[str, List[str]],
    max_dataset_size: Optional[int],
    batch_size: int,
    sequential: bool,
    num_workers: int,
) -> Tuple[BaseDataLoader, BaseDataLoader]:
    """Build one dataloader per domain for two unpaired, unlabeled datasets.

    E.g. used by cycle gan with data from two domains. Both loaders are
    created with identical settings.

    Args:
        dir_a: Path (or list of paths) to images of domain a.
        dir_b: Path (or list of paths) to images of domain b.
        max_dataset_size: Maximum amount of images to load, forwarded unchanged
            to ``BaseDataLoader``.
            NOTE(review): this was documented as "-1 means infinity" although
            the annotation is ``Optional[int]`` -- confirm which sentinel
            ``BaseDataLoader`` actually expects.
        batch_size: Input batch size.
        sequential: If true, takes images in order, otherwise takes them
            randomly (passed to the loaders as ``shuffle=not sequential``).
        num_workers: Workers for loading data.

    Returns:
        Tuple ``(loader_a, loader_b)`` for domain a and domain b.
    """
    # Build both datasets first, then wrap each one in a dataloader.
    datasets = [UnlabeledDataset(directory) for directory in (dir_a, dir_b)]

    shuffle = not sequential
    loader_a, loader_b = [
        BaseDataLoader(
            dataset=domain_dataset,
            max_dataset_size=max_dataset_size,
            batch_size=batch_size,
            num_workers=num_workers,
            shuffle=shuffle,
        )
        for domain_dataset in datasets
    ]
    return loader_a, loader_b
def sample_generator(
    dataloader: "BaseDataLoader",
    n_samples: Optional[int] = None,
):
    """Generator that samples batches from a dataloader, restarting it as needed.

    Whenever the dataloader is exhausted, ``iter(dataloader)`` is called again
    and sampling continues from the new iterator.

    Args:
        dataloader: Dataloader. It is iterated repeatedly, so it must support
            being passed to ``iter()`` more than once.
        n_samples: Number of batches of samples. None means infinity; zero or
            a negative number yields nothing.

    Yields:
        The next batch produced by the dataloader.

    Note:
        If the dataloader yields no batches at all, the generator simply ends.
    """
    iter_ = iter(dataloader)
    i = 0
    # Test against None explicitly: a truthiness check would treat the default
    # ``None`` like 0 and end immediately instead of sampling forever.
    while n_samples is None or i < n_samples:
        i += 1
        try:
            next_ = next(iter_)
        except StopIteration:
            # Dataloader exhausted: start over from the beginning.
            iter_ = iter(dataloader)
            try:
                next_ = next(iter_)
            except StopIteration:
                # A fresh iterator is already empty. Letting StopIteration
                # escape a generator would turn into RuntimeError (PEP 479),
                # so end the generator explicitly.
                return
        yield next_
def unpaired_sample_generator(
    dataloader_a: "BaseDataLoader",
    dataloader_b: "BaseDataLoader",
    n_samples: Optional[int] = None,
):
    """Generator that samples pairwise from both dataloaders.

    Each dataloader is restarted independently (via a new ``iter()`` call)
    whenever it is exhausted, so the two may have different lengths.

    Args:
        dataloader_a: Domain a dataloader.
        dataloader_b: Domain b dataloader.
        n_samples: Number of batches of samples. None means infinity; zero or
            a negative number yields nothing.

    Yields:
        Tuples ``(batch_a, batch_b)`` with one batch from each dataloader.

    Note:
        If either dataloader yields no batches at all, the generator simply
        ends.
    """
    iter_a = iter(dataloader_a)
    iter_b = iter(dataloader_b)
    i = 0
    # Test against None explicitly: a truthiness check would treat the default
    # ``None`` like 0 and end immediately instead of sampling forever.
    while n_samples is None or i < n_samples:
        i += 1

        try:
            next_a = next(iter_a)
        except StopIteration:
            # Domain a exhausted: start over from the beginning.
            iter_a = iter(dataloader_a)
            try:
                next_a = next(iter_a)
            except StopIteration:
                # Empty dataloader; avoid leaking StopIteration (PEP 479).
                return

        try:
            next_b = next(iter_b)
        except StopIteration:
            # Domain b exhausted: start over from the beginning.
            iter_b = iter(dataloader_b)
            try:
                next_b = next(iter_b)
            except StopIteration:
                # Empty dataloader; avoid leaking StopIteration (PEP 479).
                return

        yield next_a, next_b