import itertools
from collections.abc import Iterable
from typing import List, Optional, Union
import numpy as np
import torch
from PIL import Image, ImageDraw
from torchvision.transforms import transforms
def draw_boxes(image, boxes, labels=None):
    """Draw labeled boxes on an image, in place.

    Each box is outlined in red. When a (truthy) label is available for a box,
    it is written in red just above the box's top-left corner.

    Args:
        image: The PIL image to draw on; it is modified in place.
        boxes: A single box with 4 values or a batch of shape ``(N, 4+)``,
            each row read as ``(x0, y0, x1, y1)`` (any extra columns are
            ignored). Torch tensors, NumPy arrays and nested sequences are
            accepted.
        labels: ``None`` for no labels, a single label (e.g. one string) for
            a single box, or an iterable with at most one label per box.
            Boxes without a matching label are drawn without text.

    Raises:
        ValueError: If more labels than boxes are given.
    """
    if not hasattr(boxes, "reshape"):
        # Plain Python sequences: convert so ``ndim``/``reshape`` work below.
        boxes = np.asarray(boxes, dtype=float)
    if boxes.ndim == 1:
        # A lone box; ``reshape`` (unlike ``Tensor.view``) also works for
        # NumPy arrays and non-contiguous tensors.
        boxes = boxes.reshape(-1, 4)

    if labels is None:
        labels = []
    elif isinstance(labels, str) or not isinstance(labels, Iterable):
        # A single label; iterating a string would split it into characters.
        labels = [labels]
    else:
        labels = list(labels)
    # Validate before touching the image so a bad call leaves it unchanged.
    if len(labels) > len(boxes):
        raise ValueError(
            f"Got {len(labels)} labels for only {len(boxes)} boxes."
        )

    draw = ImageDraw.Draw(image)
    # Missing trailing labels are padded with None; boxes are never padded
    # because of the length check above.
    for box, label in itertools.zip_longest(boxes, labels):
        x0, y0, x1, y1 = (float(v) for v in box[:4])
        draw.rectangle([(x0, y0), (x1, y1)], outline="red")
        if label:
            # Offset the text so it sits just above the box's top-left corner.
            draw.text((x0 + 5, y0 - 15), str(label), fill="red")
def show_labeled_image(
    image: Union[Image.Image, np.ndarray, torch.Tensor],
    boxes: Optional[torch.Tensor] = None,
    labels: Optional[List[str]] = None,
    save_path: Optional[str] = None,
    show: bool = False,
) -> None:
    """Show and/or save the image with the specified boxes drawn on it.

    The boxes are drawn on a private copy of the image, so a PIL image passed
    in by the caller is neither modified nor closed.

    Args:
        image: The image that is displayed. Tensors are converted with
            ``transforms.ToPILImage`` and NumPy arrays with
            ``Image.fromarray``.
        boxes: The bounding boxes to draw on the image, or ``None`` for none.
        labels: A list of labels for each bounding box, or ``None`` for no
            labels.
        save_path: File path the annotated image is saved to, or ``None``
            (or empty) to skip saving.
        show: If the image should be shown interactively.
    """
    if isinstance(image, torch.Tensor):
        canvas = transforms.ToPILImage()(image)
    elif isinstance(image, np.ndarray):
        canvas = Image.fromarray(image)
    else:
        # Drawing is in place and the canvas is closed below, so never work
        # on the caller's own image object.
        canvas = image.copy()
    try:
        if boxes is not None:
            # Pass an empty list instead of None so every box is drawn
            # without text.
            draw_boxes(canvas, boxes, [] if labels is None else labels)
        if save_path:
            canvas.save(save_path)
        if show:
            canvas.show()
    finally:
        # Release the canvas even if saving or showing fails.
        canvas.close()