Source code for kitcar_ml.utils.visualization

import itertools
from collections.abc import Iterable
from typing import List, Optional, Union

import numpy as np
import torch
from PIL import Image, ImageDraw
from torchvision.transforms import transforms


def draw_boxes(image, boxes, labels):
    """Draw labeled boxes on an image, in place.

    Args:
        image: The PIL image to draw on. It is modified in place.
        boxes: Bounding boxes as rows of ``(x0, y0, x1, y1)``. A single box may
            be given as a flat sequence of four values. Torch tensors, numpy
            arrays and nested sequences are accepted; only the first four
            values of each row are used.
        labels: Label text for the boxes. This is either ``None`` (no labels),
            a single string (label of the first box), or an iterable of strings
            with one entry per box. Boxes without a matching label are drawn
            unlabeled, and surplus labels are ignored.
    """
    if not hasattr(boxes, "ndim"):
        # Plain Python sequences: convert so the shape handling below works.
        boxes = np.asarray(boxes)
    if boxes.ndim == 1:
        # A single box was passed; promote it to a batch of one.
        # ``reshape`` (unlike torch's ``view``) also exists on numpy arrays.
        boxes = boxes.reshape(1, 4)

    if labels is None:
        labels = []
    elif isinstance(labels, str):
        # A bare string is one label, not an iterable of one-character labels.
        labels = [labels]
    else:
        labels = list(labels)

    draw = ImageDraw.Draw(image)
    for index, box in enumerate(boxes):
        # Convert tensor/array scalars to plain floats for PIL.
        x0, y0, x1, y1 = (float(value) for value in box[:4])
        draw.rectangle([(x0, y0), (x1, y1)], outline="red")
        label = labels[index] if index < len(labels) else None
        if label:
            # Place the text just above the top-left corner of the box.
            draw.text((x0 + 5, y0 - 15), str(label), fill="red")
def show_labeled_image(
    image: Union[Image.Image, np.ndarray, torch.Tensor],
    boxes: Optional[torch.Tensor] = None,
    labels: Optional[List[str]] = None,
    save_path: Optional[str] = None,
    show: bool = False,
):
    """Show the image along with the specified boxes around detected objects.

    The caller's image is never modified. Boxes are drawn on a private copy,
    and that copy is closed before returning.

    Args:
        image: The image that is displayed. Tensors are converted with
            ``torchvision.transforms.ToPILImage`` and numpy arrays with
            ``PIL.Image.fromarray``.
        boxes: The bounding boxes on the image, forwarded to ``draw_boxes``.
        labels: A list of labels for each bounding box.
        save_path: The file path (including extension) the image is saved to.
            Nothing is saved if it is empty or ``None``.
        show: If the image should be shown interactively.
    """
    if isinstance(image, torch.Tensor):
        image = transforms.ToPILImage()(image)
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    else:
        # Work on a copy so neither the drawing nor the close() below
        # affects the image object the caller still holds.
        image = image.copy()

    try:
        if boxes is not None:
            draw_boxes(image, boxes, labels)
        if save_path:
            image.save(save_path)
        if show:
            image.show()
    finally:
        # Always release the working image, even if drawing/saving fails.
        image.close()