"""Source code for kitcar_ml.utils.data.data_loader.bbox_data_loader."""

from functools import cached_property
from typing import List

import numpy as np
import torch
import torchvision.transforms

from .base_data_loader import BaseDataLoader
from .transforms.bbox_image_transform import BBoxImageTransform


class BBoxDataLoader(BaseDataLoader):
    """Dataloader for bbox datasets that applies transformations.

    Each transform may consume several input images to produce a single
    output image (``imgs_per_output``); the loader compensates by scaling
    the underlying batch size and epoch length accordingly.
    """

    def __init__(
        self,
        *,
        transforms: List[BBoxImageTransform],
        torch_tf=torchvision.transforms.ToTensor(),
        **kwargs,
    ):
        """Initialize the loader.

        Args:
            transforms: Bbox/image transforms applied in order to every batch.
            torch_tf: Callable converting a single image to a torch tensor.
                Defaults to ``torchvision.transforms.ToTensor()`` (stateless,
                so sharing the default instance across loaders is safe).
            **kwargs: Forwarded to :class:`BaseDataLoader`; ``batch_size``
                (default 1) is scaled by :attr:`super_sampling_factor`.
        """
        self.torch_tf = torch_tf
        self.transforms = transforms

        # Adapt batch size if some transforms require multiple images!
        # The base loader must fetch enough raw images so that, after the
        # transforms have consumed them, one full output batch remains.
        kwargs["batch_size"] = kwargs.get("batch_size", 1) * self.super_sampling_factor

        super().__init__(**kwargs)

    @cached_property
    def super_sampling_factor(self) -> int:
        """How many images are consumed to produce one training image.

        Product of ``imgs_per_output`` over all transforms (1 for an empty
        transform list). Cached: ``self.transforms`` must not change after
        the first access.
        """
        # np.product was a deprecated alias removed in NumPy 2.0;
        # np.prod is the supported spelling. int64 avoids overflow for
        # large factor products; .item() yields a plain Python int.
        return np.prod(
            tuple(tf.imgs_per_output for tf in self.transforms), dtype=np.int64
        ).item()

    @cached_property
    def complete_transform(self):
        """Concatenate all transforms into a single batch-level callable."""

        def tf(batch):
            # Apply the transforms sequentially; each receives the full batch.
            for t in self.transforms:
                batch = t(batch)
            return batch

        return tf

    def prepare_batch(self, batch):
        """Prepare a batch before it is converted to torch.

        Args:
            batch: Iterable of ``(img, bboxes, labels)`` triples.

        Returns:
            Tuple ``(images, bboxes, labels)`` where images are torch tensors
            (via ``self.torch_tf``) and bboxes are float32 tensors of shape
            ``(num_boxes, 4)``.
        """
        # Transforms operate on numpy images, so convert first.
        batch = [(np.array(img), bboxes, labels) for img, bboxes, labels in batch]
        batch = self.complete_transform(batch)
        images, bboxes, labels = zip(*batch)

        # Turn images into torch format.
        images = [self.torch_tf(image) for image in images]

        # Turn bboxes into torch tensors unless a transform already did.
        if not isinstance(bboxes[0], torch.Tensor):
            bboxes = [torch.tensor(boxes, dtype=torch.float32) for boxes in bboxes]
        # view(-1, 4) ensures that empty bbox tensors have the right dimensions!
        bboxes = [boxes.view(-1, 4) for boxes in bboxes]

        return images, bboxes, labels

    def __len__(self):
        """Return the number of data in the dataset.

        Scaled by the super-sampling factor because :meth:`__iter__` replays
        the base iterator that many times per epoch.
        """
        return super().__len__() * self.super_sampling_factor

    def __iter__(self):
        """Yield batches, replaying the base iterator once per factor pass."""
        for _ in range(self.super_sampling_factor):
            yield from super().__iter__()