Source code for kitcar_ml.utils.data.data_loader.transforms.mosaic_transform
from typing import List, Tuple
import more_itertools
import numpy as np
from kitcar_ml.utils.data.data_loader.transforms.bbox_image_transform import (
BBoxImageTransform,
)
class MosaicTransform(BBoxImageTransform):
    """Transform multiple images into a grid/mosaic image.

    All images before this transformation should have the same shape. A random cropping
    before this transformation adds variability.
    """

    def __init__(
        self,
        rows: int = 2,
        columns: int = 2,
    ):
        """Initialize a mosaic transformation.

        Args:
            rows: Number of rows in the grid.
            columns: Number of columns in the grid.

        Raises:
            ValueError: If ``rows`` or ``columns`` is smaller than one.
        """
        # A grid without cells would silently produce an empty output batch
        # (or a division by zero when shifting boxes), so fail early instead.
        if rows < 1 or columns < 1:
            raise ValueError(
                f"rows and columns must be at least 1, got rows={rows}, columns={columns}"
            )
        self.rows = rows
        self.columns = columns

    @property
    def imgs_per_output(self) -> int:
        """Images needed to produce one mosaic image."""
        return self.rows * self.columns

    def __call__(
        self, batch: List[Tuple[np.ndarray, List[np.ndarray], List[str]]]
    ) -> List[Tuple[np.ndarray, List[np.ndarray], List[str]]]:
        """Turn a batch of images into a batch of mosaic images.

        The images are concatenated together to form a grid. If the batch size is
        not a multiple of :attr:`imgs_per_output`, the last group is padded with
        the first element of the batch.

        Args:
            batch: List of images, bboxes and labels put together in a tuple.

        Returns:
            The mosaic batch. An empty batch yields an empty list.
        """
        # The fill value below reads batch[0], which does not exist for an empty batch.
        if not batch:
            return []

        split_batch = more_itertools.grouper(
            batch, self.imgs_per_output, fillvalue=batch[0]
        )
        return [self.calculate_mosaic_image(mosaic_batch) for mosaic_batch in split_batch]

    def calculate_mosaic_image(
        self, mosaic_batch: List[Tuple[np.ndarray, List[np.ndarray], List[str]]]
    ) -> Tuple[np.ndarray, List[np.ndarray], List[str]]:
        """Calculate a mosaic image from a batch of images.

        Args:
            mosaic_batch: The (image, bboxes, labels) tuples that form one mosaic,
                ordered row by row.

        Returns:
            The grid image, all bounding boxes shifted into grid coordinates and
            the labels of all images, both in the order of the input images.
        """
        batch_images, batch_boxes, batch_labels = zip(*mosaic_batch)
        reference_shape = batch_images[0].shape
        assert all(
            image.shape == reference_shape for image in batch_images
        ), "All images should have the same shape, to create a mosaic image"

        # Use first image, because all images should have the same shape.
        # Only the two leading axes are read, so images without a channel axis work too.
        height, width = reference_shape[:2]

        mosaic_boxes = [
            self.shift_box(box, index, height, width)
            for index, boxes_per_image in enumerate(batch_boxes)
            for box in boxes_per_image
        ]
        mosaic_image = self.create_grid_image(batch_images)
        # Flatten the per-image label lists in image order.
        mosaic_labels = [label for labels in batch_labels for label in labels]
        return mosaic_image, mosaic_boxes, mosaic_labels

    def shift_box(self, box: np.ndarray, index: int, height: int, width: int) -> np.ndarray:
        """Shift a bounding box dependent on the position in the grid and the height and
        width of the images.

        Entries 0 and 2 of the box are shifted horizontally, entries 1 and 3 vertically.

        Args:
            box: The bounding box that is shifted.
            index: The index of the corresponding image.
            height: The height of one image.
            width: The width of one image.

        Returns:
            The shifted box as an array with four entries.
        """
        # Images are placed row by row: the row index sets the vertical offset,
        # the column index the horizontal one.
        height_shift = (index // self.columns) * height
        width_shift = (index % self.columns) * width
        return np.array(
            [
                box[0] + width_shift,
                box[1] + height_shift,
                box[2] + width_shift,
                box[3] + height_shift,
            ]
        )

    def create_grid_image(self, images: List[np.ndarray]) -> np.ndarray:
        """Concatenate the images to a new image.

        Args:
            images: The images of one mosaic, ordered row by row.

        Returns:
            One image with ``self.columns`` images side by side in every row.
        """
        grid_rows = [
            np.concatenate(one_row, axis=1)
            for one_row in more_itertools.chunked(images, self.columns)
        ]
        return np.concatenate(grid_rows, axis=0)