import os
from typing import List, Optional, Tuple, Union
from ..labeled_dataset import LabeledDataset
from ..unlabeled_dataset import UnlabeledDataset
from .base_data_loader import BaseDataLoader
def load_labeled_dataset(
    label_file: str,
    max_dataset_size: Optional[int],
    batch_size: int,
    sequential: bool,
    num_workers: int,
) -> BaseDataLoader:
    """Build a dataloader over the labeled dataset described by ``label_file``.

    Args:
        label_file: Path to the label file, parsed with
            ``LabeledDataset.from_yaml``. Its directory is stored as the
            dataset's base path.
        max_dataset_size: Maximum number of images to load; None means no limit.
        batch_size: Number of samples per batch.
        sequential: If True, samples are taken in order; otherwise randomly.
        num_workers: Number of workers used for loading data.

    Returns:
        A ``BaseDataLoader`` wrapping the labeled dataset.
    """
    labeled = LabeledDataset.from_yaml(label_file)
    # NOTE(review): presumably relative paths inside the label file are
    # resolved against this directory — confirm in LabeledDataset.
    labeled._base_path = os.path.dirname(label_file)
    shuffle = not sequential
    return BaseDataLoader(
        dataset=labeled,
        max_dataset_size=max_dataset_size,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
    )
def load_unpaired_unlabeled_datasets(
    dir_a: Union[str, List[str]],
    dir_b: Union[str, List[str]],
    max_dataset_size: Optional[int],
    batch_size: int,
    sequential: bool,
    num_workers: int,
) -> Tuple[BaseDataLoader, BaseDataLoader]:
    """Create dataloaders for two unpaired, unlabeled datasets.

    Used e.g. by CycleGAN with data from two domains. Both loaders are built
    with identical loading settings.

    Args:
        dir_a: Path (or list of paths) to the images of domain a.
        dir_b: Path (or list of paths) to the images of domain b.
        max_dataset_size: Maximum number of images to load per domain;
            None means no limit.
        batch_size: Number of samples per batch.
        sequential: If True, samples are taken in order; otherwise randomly.
        num_workers: Number of workers used for loading data.

    Returns:
        Tuple ``(loader_a, loader_b)`` for domain a and domain b.
    """
    # Build both datasets before any loader, preserving the original
    # construction order (a failing dir_b is reported before loaders exist).
    datasets = (UnlabeledDataset(dir_a), UnlabeledDataset(dir_b))
    loader_a, loader_b = (
        BaseDataLoader(
            dataset=dataset,
            max_dataset_size=max_dataset_size,
            batch_size=batch_size,
            num_workers=num_workers,
            shuffle=not sequential,
        )
        for dataset in datasets
    )
    return loader_a, loader_b
def sample_generator(
    dataloader: "BaseDataLoader",
    n_samples: Optional[int] = None,
):
    """Yield batches from a dataloader, restarting it whenever it runs out.

    Args:
        dataloader: Dataloader (any re-iterable) to draw batches from. A new
            iterator is created with ``iter(dataloader)`` each time the
            current one is exhausted.
        n_samples: Number of batches to yield. None means infinitely many.

    Yields:
        The next batch produced by ``dataloader``.

    Raises:
        RuntimeError: If ``dataloader`` produces no batches at all.
    """
    iter_ = iter(dataloader)
    yielded = 0
    # `n_samples is None` must be tested explicitly: a plain truthiness check
    # would stop immediately for the documented "infinite" default.
    while n_samples is None or yielded < n_samples:
        try:
            next_ = next(iter_)
        except StopIteration:
            # Current pass is exhausted: start a new one.
            iter_ = iter(dataloader)
            try:
                next_ = next(iter_)
            except StopIteration:
                # A fresh pass is empty too; fail clearly instead of letting
                # StopIteration escape the generator (PEP 479).
                raise RuntimeError(
                    "Cannot sample from an empty dataloader."
                ) from None
        yielded += 1
        yield next_
def unpaired_sample_generator(
    dataloader_a: "BaseDataLoader",
    dataloader_b: "BaseDataLoader",
    n_samples: Optional[int] = None,
):
    """Yield pairs of batches, one from each dataloader.

    Each dataloader is restarted independently (via a new ``iter()``) when it
    is exhausted, so loaders of different lengths can be paired indefinitely.

    Args:
        dataloader_a: Domain a dataloader.
        dataloader_b: Domain b dataloader.
        n_samples: Number of batch pairs to yield. None means infinitely many.

    Yields:
        Tuple ``(batch_a, batch_b)``.

    Raises:
        RuntimeError: If either dataloader produces no batches at all.
    """
    iter_a = iter(dataloader_a)
    iter_b = iter(dataloader_b)
    yielded = 0
    # Explicit None check: a truthiness test would yield nothing by default.
    while n_samples is None or yielded < n_samples:
        try:
            next_a = next(iter_a)
        except StopIteration:
            iter_a = iter(dataloader_a)
            try:
                next_a = next(iter_a)
            except StopIteration:
                raise RuntimeError(
                    "Cannot sample from an empty dataloader (domain a)."
                ) from None
        try:
            next_b = next(iter_b)
        except StopIteration:
            iter_b = iter(dataloader_b)
            try:
                next_b = next(iter_b)
            except StopIteration:
                raise RuntimeError(
                    "Cannot sample from an empty dataloader (domain b)."
                ) from None
        yielded += 1
        yield next_a, next_b