Source code for kitcar_ml.utils.data.data_loader.transforms.albumentations_transform

from typing import List, Tuple

import albumentations
import numpy as np

from .bbox_image_transform import BBoxImageTransform


class AlbumentationsComposer(albumentations.Compose):
    """``albumentations.Compose`` adapted to a positional tuple interface.

    The parent class is invoked with keyword arguments and yields a dict-like
    result. This subclass instead accepts ``(image, bboxes, labels)``
    positionally and hands back the transformed values as a tuple in the
    same order.

    NOTE(review): labels are forwarded under the ``class_labels`` key, so the
    composer presumably has to be constructed with ``bbox_params`` whose
    ``label_fields`` include ``"class_labels"``. Confirm at the construction
    site.
    """

    def __call__(self, image, bboxes, labels):
        """Run the composed augmentation pipeline on one sample.

        Args:
            image: Forwarded to albumentations as ``image``.
            bboxes: Forwarded to albumentations as ``bboxes``.
            labels: Forwarded to albumentations as ``class_labels``.

        Returns:
            A 3-tuple with the ``image``, ``bboxes`` and ``class_labels``
            entries of the albumentations result, in that order.
        """
        result = super().__call__(
            image=image,
            bboxes=bboxes,
            class_labels=labels,
        )
        # Unpack the result dict back into the positional order of the inputs.
        output_keys = ("image", "bboxes", "class_labels")
        return tuple(result[key] for key in output_keys)
class AlbumentationsTransform(BBoxImageTransform):
    """Batch transform that augments every element with an albumentations pipeline.

    Each batch element is unpacked positionally into an
    ``AlbumentationsComposer`` built once from the given augmentations.
    """

    def __init__(
        self,
        augmentations: List[albumentations.BasicTransform],
        **compose_kwargs,
    ):
        """Initialize an albumentations transformation.

        Args:
            augmentations: Transformations that will be applied to augment loaded data
            **compose_kwargs: Other keyword arguments passed to albumentations.Compose.
        """
        # NOTE(review): BBoxImageTransform.__init__ is not invoked here. Confirm
        # the base class needs no initialization of its own.
        # The pipeline is built a single time and reused for every batch element.
        self.augmentator = AlbumentationsComposer(augmentations, **compose_kwargs)

    def __call__(
        self, batch: List[Tuple[np.ndarray, List[np.ndarray], List[str]]]
    ) -> List[Tuple[np.ndarray, List[np.ndarray], List[str]]]:
        """Augment each element of ``batch`` independently.

        Args:
            batch: Sequence of per-sample tuples. Each tuple is unpacked
                positionally into the composer (``image, bboxes, labels``).

        Returns:
            A new list containing the composer's output for each element, in
            the original order.
        """
        augmented = []
        for element in batch:
            augmented.append(self.augmentator(*element))
        return augmented