from functools import cached_property
from typing import List
import numpy as np
import torch
import torchvision.transforms
from .base_data_loader import BaseDataLoader
from .transforms.bbox_image_transform import BBoxImageTransform
class BBoxDataLoader(BaseDataLoader):
    """Dataloader for bbox datasets that applies bbox-aware transformations.

    Each dataset item is an ``(image, bboxes, labels)`` triple. The configured
    ``transforms`` are applied to whole batches; some of them may consume
    several input images to produce a single output image (their
    ``imgs_per_output`` attribute), which is compensated for by inflating the
    underlying batch size (see ``super_sampling_factor``).
    """

    def __init__(
        self,
        *,
        transforms: List[BBoxImageTransform],
        torch_tf=None,
        **kwargs,
    ):
        """Create the dataloader.

        Args:
            transforms: Batch-level bbox/image transforms applied in order.
            torch_tf: Per-image conversion to torch format. Defaults to
                ``torchvision.transforms.ToTensor()``.
            **kwargs: Forwarded to ``BaseDataLoader``; ``batch_size`` is
                scaled by ``super_sampling_factor`` before forwarding.
        """
        # Default lazily: avoids sharing one ToTensor instance across all
        # dataloaders, and makes an explicit ``torch_tf=None`` safe.
        self.torch_tf = (
            torchvision.transforms.ToTensor() if torch_tf is None else torch_tf
        )
        self.transforms = transforms
        # Some transforms consume multiple input images per output image, so
        # the underlying loader must fetch proportionally larger batches for
        # the caller to receive the batch size it asked for.
        kwargs["batch_size"] = kwargs.get("batch_size", 1) * self.super_sampling_factor
        super().__init__(**kwargs)

    @cached_property
    def super_sampling_factor(self):
        """How many input images are consumed to produce one training image."""
        # ``np.product`` was a deprecated alias removed in NumPy 2.0; ``np.prod``
        # is the supported spelling. An empty ``transforms`` list yields the
        # multiplicative identity 1.
        return int(
            np.prod(
                [tf.imgs_per_output for tf in self.transforms], dtype=np.int64
            )
        )

    @cached_property
    def complete_transform(self):
        """Return a callable applying all transforms to a batch, in order."""

        def tf(batch):
            for t in self.transforms:
                batch = t(batch)
            return batch

        return tf

    def prepare_batch(self, batch):
        """Prepare a batch before it is converted to torch.

        Args:
            batch: Iterable of ``(image, bboxes, labels)`` triples.

        Returns:
            Tuple of (list of torch images, list of float32 ``(N, 4)`` bbox
            tensors, tuple of labels).
        """
        batch = [(np.array(img), bboxes, labels) for img, bboxes, labels in batch]
        batch = self.complete_transform(batch)
        images, bboxes, labels = zip(*batch)
        # Turn images into torch format
        images = [self.torch_tf(image) for image in images]
        # Turn bboxes into torch (transforms may already emit tensors)
        if not isinstance(bboxes[0], torch.Tensor):
            bboxes = [torch.tensor(boxes, dtype=torch.float32) for boxes in bboxes]
        # view(-1,4) ensures that empty bbox tensors have the right dimensions!
        bboxes = [boxes.view(-1, 4) for boxes in bboxes]
        return images, bboxes, labels

    def __len__(self):
        """Return the number of data in the dataset.

        Scaled by ``super_sampling_factor`` to match ``__iter__``, which
        passes over the underlying data that many times.
        """
        return super().__len__() * self.super_sampling_factor

    def __iter__(self):
        """Yield batches, iterating the underlying data multiple times.

        One full pass per consumed input image keeps the number of produced
        training images consistent with ``__len__``.
        """
        for _ in range(self.super_sampling_factor):
            yield from super().__iter__()