import os
import random
from typing import Optional
import cv2
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from kitcar_ml.utils.data.labeled_dataset import LabeledDataset
class VisualLabeledDataset(LabeledDataset):
    """Utilities for visualizing a labeled dataset."""

    def create_debug_images(
        self, output_dir: str = "debug", sample_size: Optional[int] = None
    ):
        """Draw all labels onto their images and save the annotated copies.

        For every label of an image, a rectangle spanning (x1, y1)-(x2, y2)
        is drawn and the class name is written just above its top-left
        corner. If the dataset attributes contain ``x_ground``, a cross
        marker is additionally drawn at (x_ground, y_ground).

        Annotated images are written to
        ``<self._base_path>/<output_dir>/<image key>``. Images that cannot
        be read are skipped with a message instead of aborting the run, and
        a message is emitted when an annotated image cannot be written.

        Args:
            output_dir: Directory (joined onto ``self._base_path``) that
                receives the annotated images. Created if it does not exist.
            sample_size: If given and smaller than the number of entries in
                ``self.labels``, only a random sample of that many images is
                processed; otherwise all images are processed.
        """
        output_root = os.path.join(self._base_path, output_dir)
        os.makedirs(output_root, exist_ok=True)

        # Resolve the position of every attribute once, outside the loops.
        x1_index = self.attributes.index("x1")
        y1_index = self.attributes.index("y1")
        x2_index = self.attributes.index("x2")
        y2_index = self.attributes.index("y2")
        class_id_index = self.attributes.index("class_id")

        has_ground_point = "x_ground" in self.attributes
        if has_ground_point:
            x_ground_index = self.attributes.index("x_ground")
            y_ground_index = self.attributes.index("y_ground")

        # Use a colormap to provide different colors for all labels.
        # The colormap yields RGBA values in [0, 1]; the alpha channel is
        # dropped and the remaining channels are scaled to [0, 255].
        color_map = plt.get_cmap("gist_rainbow")
        colors = dict(
            zip(
                self.classes.keys(),
                255 * color_map(np.linspace(0, 1, len(self.classes)))[..., :-1],
            )
        )

        if sample_size is not None and sample_size < len(self.labels):
            # Sort the keys so that the sampled population has a defined order.
            samples_keys = random.sample(sorted(self.labels.keys()), k=sample_size)
            samples = {key: self.labels[key] for key in samples_keys}
        else:
            samples = self.labels

        print("Create debug images ...")
        for img_file, labels in tqdm(samples.items()):
            input_path = os.path.join(self._base_path, img_file)
            output_path = os.path.join(output_root, img_file)

            img = cv2.imread(input_path)
            if img is None:
                # cv2.imread signals a missing/undecodable file by returning
                # None rather than raising; drawing on it would crash.
                tqdm.write("Skipping unreadable image: {}".format(input_path))
                continue

            for label in labels:
                # Get label data
                class_id = label[class_id_index]
                class_name = self.classes[class_id]
                color = colors[class_id]

                # Draw bounding box
                cv2.rectangle(
                    img,
                    (label[x1_index], label[y1_index]),
                    (label[x2_index], label[y2_index]),
                    color,
                    2,
                )

                # Write class name at the top of the label
                cv2.putText(
                    img,
                    text=class_name,
                    org=(label[x1_index], label[y1_index] - 5),
                    fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                    fontScale=0.9,
                    color=color,
                    thickness=2,
                )

                if has_ground_point:
                    # Mark the ground point with a cross
                    cv2.drawMarker(
                        img,
                        (label[x_ground_index], label[y_ground_index]),
                        color,
                        markerType=cv2.MARKER_CROSS,
                        markerSize=10,
                    )

            # Save output. The image key may contain sub-directories, which
            # must exist below the output directory before writing.
            os.makedirs(os.path.dirname(output_path), exist_ok=True)
            if not cv2.imwrite(output_path, img):
                tqdm.write("Failed to write debug image: {}".format(output_path))