Source code for kitcar_ml.utils.data.visual_labeled_dataset

import os
import random
from typing import Optional

import cv2
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

from kitcar_ml.utils.data.labeled_dataset import LabeledDataset


class VisualLabeledDataset(LabeledDataset):
    """Utilities for visualizing a labeled dataset."""

    def create_debug_images(
        self, output_dir: str = "debug", sample_size: Optional[int] = None
    ):
        """Draw all labels onto their images and save the annotated copies.

        For every selected image, each label's bounding box is drawn together
        with its class name. If the dataset's attributes contain ``x_ground``,
        a cross marker is drawn at the label's ground point as well. The
        annotated image is written to ``<base_path>/<output_dir>/<image key>``.

        Images that cannot be read or written are reported and skipped, so a
        single broken file does not abort the whole run.

        Args:
            output_dir: Directory (relative to the dataset's base path) that
                receives the annotated images. Created if it does not exist.
            sample_size: If given and smaller than the number of labeled
                images, only a random sample of that many images is processed.
                Otherwise all labeled images are processed.
        """
        os.makedirs(os.path.join(self._base_path, output_dir), exist_ok=True)

        x1_index = self.attributes.index("x1")
        y1_index = self.attributes.index("y1")
        x2_index = self.attributes.index("x2")
        y2_index = self.attributes.index("y2")
        class_id_index = self.attributes.index("class_id")

        has_ground_point = "x_ground" in self.attributes
        if has_ground_point:
            x_ground_index = self.attributes.index("x_ground")
            y_ground_index = self.attributes.index("y_ground")

        # Use a colormap to provide different colors for all labels.
        # NOTE: the colormap yields RGB(A) values while OpenCV interprets the
        # triple as BGR. The colors only serve to tell classes apart, so the
        # channel order is left untouched.
        color_map = plt.get_cmap("gist_rainbow")
        colors = dict(
            zip(
                self.classes.keys(),
                255 * color_map(np.linspace(0, 1, len(self.classes)))[..., :-1],
            )
        )

        if sample_size is not None and sample_size < len(self.labels):
            # Sort the keys first so that sampling is reproducible for a
            # seeded random generator.
            samples_keys = random.sample(sorted(self.labels), k=sample_size)
            samples = {key: self.labels[key] for key in samples_keys}
        else:
            samples = self.labels

        print("Create debug images ...")
        for img_file, labels in tqdm(samples.items()):
            input_path = os.path.join(self._base_path, img_file)
            output_path = os.path.join(self._base_path, output_dir, img_file)

            img = cv2.imread(input_path)
            if img is None:
                # cv2.imread signals a missing/unreadable file by returning
                # None instead of raising; drawing on it would fail with an
                # opaque OpenCV error.
                tqdm.write(f"Could not read image, skipping: {input_path}")
                continue

            for label in labels:
                # Get label data
                class_id = label[class_id_index]
                class_name = self.classes[class_id]
                color = colors[class_id]

                # OpenCV drawing functions require integer pixel coordinates.
                x1, y1 = int(label[x1_index]), int(label[y1_index])
                x2, y2 = int(label[x2_index]), int(label[y2_index])

                # Draw bounding box
                cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)

                # Write class name at the top of the label
                cv2.putText(
                    img,
                    text=class_name,
                    org=(x1, y1 - 5),
                    fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                    fontScale=0.9,
                    color=color,
                    thickness=2,
                )

                if has_ground_point:
                    # Mark the ground point with a cross
                    cv2.drawMarker(
                        img,
                        (int(label[x_ground_index]), int(label[y_ground_index])),
                        color,
                        markerType=cv2.MARKER_CROSS,
                        markerSize=10,
                    )

            # Image keys may contain sub-directories; cv2.imwrite does not
            # create missing folders, so create them before saving.
            os.makedirs(os.path.dirname(output_path), exist_ok=True)

            # Save output
            if not cv2.imwrite(output_path, img):
                tqdm.write(f"Could not write debug image: {output_path}")