Source code for kitcar_ml.utils.data.data_loader.transforms.mosaic_transform

from typing import List, Tuple

import more_itertools
import numpy as np

from kitcar_ml.utils.data.data_loader.transforms.bbox_image_transform import (
    BBoxImageTransform,
)


class MosaicTransform(BBoxImageTransform):
    """Transform multiple images into a grid/mosaic image.

    All images before this transformation should have the same shape.
    A random cropping before this transformation adds variability.

    The grid is filled row by row: image ``i`` of a group lands in row
    ``i // columns`` and column ``i % columns``.
    """

    def __init__(
        self,
        rows: int = 2,
        columns: int = 2,
    ):
        """Initialize a mosaic transformation.

        Args:
            rows: Number of rows in the grid.
            columns: Number of columns in the grid.
        """
        self.rows = rows
        self.columns = columns

    @property
    def imgs_per_output(self) -> int:
        """Images needed to produce one mosaic image."""
        return self.rows * self.columns

    def __call__(
        self, batch: List[Tuple[np.ndarray, List[np.ndarray], List[str]]]
    ) -> List[Tuple[np.ndarray, List[np.ndarray], List[str]]]:
        """Turn a batch of images into a batch of mosaic images.

        The batch is split into consecutive groups of ``imgs_per_output``
        entries and the images of each group are concatenated together to form
        a grid. If the batch size is not a multiple of ``imgs_per_output``,
        the last group is padded with the first entry of the batch.

        Args:
            batch: List of images, bboxes and labels put together in a tuple.

        Returns:
            The mosaic batch. An empty batch yields an empty list.
        """
        # The padding value below is batch[0], which does not exist for an
        # empty batch; there is nothing to transform in that case anyway.
        if not batch:
            return []

        split_batch = more_itertools.grouper(
            batch, self.imgs_per_output, fillvalue=batch[0]
        )
        return [
            self.calculate_mosaic_image(mosaic_batch) for mosaic_batch in split_batch
        ]

    def calculate_mosaic_image(
        self, mosaic_batch: List[Tuple[np.ndarray, List[np.ndarray], List[str]]]
    ) -> Tuple[np.ndarray, List[np.ndarray], List[str]]:
        """Calculate a mosaic image from a batch of images.

        Args:
            mosaic_batch: The (image, bboxes, labels) tuples that make up one
                mosaic image, in row-major grid order.

        Returns:
            The grid image, the bounding boxes of all images shifted to their
            position in the grid and the labels of all images, in the same
            order as the boxes.
        """
        batch_images, batch_boxes, batch_labels = zip(*mosaic_batch)

        # Use first image as reference, because all images should have the same shape.
        reference_shape = batch_images[0].shape
        assert all(
            image.shape == reference_shape for image in batch_images
        ), "All images should have the same shape, to create a mosaic image"

        # Only the first two axes are needed; slicing instead of unpacking
        # three values also allows images without a channel axis.
        height, width = reference_shape[:2]

        mosaic_boxes = [
            self.shift_box(box, index, height, width)
            for index, boxes_per_image in enumerate(batch_boxes)
            for box in boxes_per_image
        ]
        mosaic_image = self.create_grid_image(batch_images)
        # Flatten in a single pass; sum(lists, []) copies the accumulated
        # list once per image.
        mosaic_labels = [label for labels in batch_labels for label in labels]

        return mosaic_image, mosaic_boxes, mosaic_labels

    def shift_box(
        self, box: np.ndarray, index: int, height: int, width: int
    ) -> np.ndarray:
        """Shift a bounding box dependent on the position in the grid and the
        height and width of the images.

        Elements 0 and 2 of the box are shifted horizontally, elements 1 and 3
        vertically.

        Args:
            box: The bounding box that is shifted.
            index: The index of the corresponding image.
            height: The height of one image.
            width: The width of one image.

        Returns:
            The shifted bounding box as a new array with four elements.
        """
        height_shift = (index // self.columns) * height
        width_shift = (index % self.columns) * width

        # Return an array (not a plain list) so that the boxes keep the type
        # they had before the transformation.
        return np.array(
            [
                box[0] + width_shift,
                box[1] + height_shift,
                box[2] + width_shift,
                box[3] + height_shift,
            ]
        )

    def create_grid_image(self, images: List[np.ndarray]) -> np.ndarray:
        """Concatenate the images to a new image.

        The images are chunked into rows of ``columns`` images each; every row
        is joined along axis 1 and the rows are then stacked along axis 0.

        Args:
            images: The images of one mosaic, in row-major grid order.

        Returns:
            The grid image.
        """
        grid_rows = [
            np.concatenate(one_row, axis=1)
            for one_row in more_itertools.chunked(images, self.columns)
        ]
        return np.concatenate(grid_rows, axis=0)