Source code for kitcar_ml.utils.data.data_loader.transforms.albumentations_transform
from typing import List, Tuple
import albumentations
import numpy as np
from .bbox_image_transform import BBoxImageTransform
class AlbumentationsComposer(albumentations.Compose):
    """``albumentations.Compose`` called with positional ``(image, bboxes, labels)``.

    Albumentations takes its targets as keyword arguments and returns a dict.
    This subclass adapts that interface to plain tuples. The labels are
    forwarded under the ``class_labels`` key. The augmented image, boxes and
    labels come back as a 3-tuple in that order.
    """

    def __call__(self, image, bboxes, labels):
        """Run the composed augmentations on a single sample.

        Args:
            image: Passed to albumentations as the ``image`` target.
            bboxes: Passed to albumentations as the ``bboxes`` target.
            labels: Passed to albumentations as the ``class_labels`` target.

        Returns:
            Tuple ``(image, bboxes, class_labels)`` taken from the
            albumentations result.
        """
        # NOTE(review): for albumentations to keep labels aligned with boxes
        # that get dropped/cropped, the Compose kwargs presumably include
        # bbox_params with label_fields=["class_labels"]. Confirm at the
        # construction sites.
        targets = {"image": image, "bboxes": bboxes, "class_labels": labels}
        result = super().__call__(**targets)
        return tuple(result[key] for key in ("image", "bboxes", "class_labels"))
class AlbumentationsTransform(BBoxImageTransform):
    """Augment every (image, bboxes, labels) element of a batch.

    Each element is fed through an ``AlbumentationsComposer`` built from the
    augmentations given at construction time.
    """

    def __init__(
        self,
        augmentations: List[albumentations.BasicTransform],
        **compose_kwargs,
    ):
        """Build the albumentations pipeline used for augmentation.

        Args:
            augmentations: Transformations applied, in order, to augment
                loaded data.
            **compose_kwargs: Additional keyword arguments forwarded unchanged
                to albumentations.Compose.
        """
        # NOTE(review): BBoxImageTransform.__init__ is not invoked here.
        # Confirm the base class needs no initialisation of its own.
        composer = AlbumentationsComposer(augmentations, **compose_kwargs)
        self.augmentator = composer

    def __call__(
        self, batch: List[Tuple[np.ndarray, List[np.ndarray], List[str]]]
    ) -> List[Tuple[np.ndarray, List[np.ndarray], List[str]]]:
        """Apply the augmentation pipeline to each element of ``batch``.

        Args:
            batch: Sequence of elements, each unpacked positionally into the
                composer's ``(image, bboxes, labels)`` arguments.

        Returns:
            A new list holding the composer's result for every element, in
            the same order as ``batch``.
        """
        augmented = []
        for element in batch:
            augmented.append(self.augmentator(*element))
        return augmented