Source code for kitcar_ml.utils.data.analyse_bbox_dataset

import os
from argparse import ArgumentParser
from functools import cached_property
from typing import List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
from tabulate import tabulate

from kitcar_ml.utils.bounding_box import BoundingBox
from kitcar_ml.utils.data.bbox_dataset import BBoxDataset


class AnalyseBBoxDataset(BBoxDataset):
    """Bounding box dataset with helpers for statistics, plots and a text report.

    All results are derived from ``self.labels``, ``self.attributes`` and
    ``self.classes`` provided by :class:`BBoxDataset`.
    """

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _as_name_list(
        class_names: Optional[Union[str, List[str]]]
    ) -> Optional[List[str]]:
        """Normalize a class-name filter to a list, or None for "no filter".

        A single string becomes a one-element list, an empty string becomes an
        empty list (matches nothing), any other iterable is copied into a list.
        """
        if class_names is None:
            return None
        if isinstance(class_names, str):
            return [class_names] if class_names else []
        return list(class_names)

    @staticmethod
    def _mean_std(values: List[float]) -> Tuple[float, float]:
        """Return ``(mean, std)`` of ``values`` as Python floats.

        Returns ``(nan, nan)`` for an empty list. np.mean/np.std would return the
        same nan values, but also emit RuntimeWarnings.
        """
        if not values:
            return float("nan"), float("nan")
        return np.mean(values).item(), np.std(values).item()

    @staticmethod
    def _format_stats(stats: Tuple[float, float]) -> str:
        """Format a ``(mean, std)`` pair for the report table."""
        mean, std = stats
        return f"{mean:.2f} +-{std:.2f}"

    @staticmethod
    def _plot_title(base: str, names: Optional[List[str]]) -> str:
        """Build a plot title, appending the class filter if one is active."""
        if not names:
            return base
        shown = names[0] if len(names) == 1 else names
        return f"{base} for class: {shown}"

    @staticmethod
    def _finish_plot(output_path: Optional[str], show: bool) -> None:
        """Save the current figure (if requested) and then optionally show it.

        Saving happens first: once an interactive window has been closed the
        figure may be gone, and ``savefig`` would write an empty image.
        """
        if output_path:
            plt.savefig(output_path)
        if show:
            try:
                plt.get_current_fig_manager().window.showMaximized()
            except AttributeError:
                # ``window.showMaximized`` is a Qt widget method; managers of other
                # backends do not provide it. Maximizing is purely cosmetic, so the
                # figure is shown at its default size instead.
                pass
            plt.show()

    # --------------------------------------------------------------- statistics

    def statistical_height(
        self, class_names: Optional[Union[str, List[str]]] = None
    ) -> Tuple[float, float]:
        """Get the average and standard deviation height of all bounding boxes.

        Args:
            class_names: Filter bounding boxes by class names.
                If None, use all bounding boxes.

        Returns:
            ``(mean, std)``; both are nan if no bounding box matches the filter.
        """
        return self._mean_std([box.height for box in self.bboxes_by_class(class_names)])

    def statistical_width(
        self, class_names: Optional[Union[str, List[str]]] = None
    ) -> Tuple[float, float]:
        """Get the average and standard deviation width of all bounding boxes.

        Args:
            class_names: Filter bounding boxes by class names.
                If None, use all bounding boxes.

        Returns:
            ``(mean, std)``; both are nan if no bounding box matches the filter.
        """
        return self._mean_std([box.width for box in self.bboxes_by_class(class_names)])

    def statistical_aspect_ratio(
        self, class_names: Optional[Union[str, List[str]]] = None
    ) -> Tuple[float, float]:
        """Get the average and standard deviation aspect ratio (width / height).

        Boxes with zero height are skipped because their aspect ratio is undefined.
        Dividing by them would otherwise abort the whole computation.

        Args:
            class_names: Filter bounding boxes by class names.
                If None, use all bounding boxes.

        Returns:
            ``(mean, std)``; both are nan if no usable bounding box matches.
        """
        ratios = [
            box.width / box.height
            for box in self.bboxes_by_class(class_names)
            if box.height
        ]
        return self._mean_std(ratios)

    @cached_property
    def img_size(self) -> Tuple[int, int]:
        """Get the image size, read from ``.size`` of the first sample's image.

        The heatmap treats the first entry as the extent along x and the second
        as the extent along y.

        NOTE(review): presumably the sample image is a PIL image, whose ``size`` is
        ``(width, height)``. Confirm this against BBoxDataset.__getitem__.

        Assumption: all images have the same size.
        """
        return self[0][0].size

    def bboxes_by_class(
        self, class_names: Optional[Union[str, List[str]]] = None
    ) -> List[BoundingBox]:
        """List bounding boxes, optionally filtered by class name.

        Args:
            class_names: A class name or a list of class names to keep.
                If None, use all bounding boxes.

        Returns:
            One :class:`BoundingBox` per matching label row, in ``self.labels`` order.
        """
        wanted = self._as_name_list(class_names)
        # Set lookup keeps the per-label membership test O(1).
        allowed = None if wanted is None else set(wanted)

        # Column positions of the needed fields inside each label row.
        x1_index = self.attributes.index("x1")
        x2_index = self.attributes.index("x2")
        y1_index = self.attributes.index("y1")
        y2_index = self.attributes.index("y2")
        class_id_index = self.attributes.index("class_id")

        return [
            BoundingBox(
                label[x1_index],
                label[y1_index],
                label[x2_index],
                label[y2_index],
                self.classes[label[class_id_index]],
            )
            for labels in self.labels.values()
            for label in labels
            if allowed is None or self.classes[label[class_id_index]] in allowed
        ]

    # -------------------------------------------------------------------- plots

    def draw_class_distribution(
        self, output_path: Optional[str] = None, show: bool = False
    ) -> None:
        """Draw a bar chart with the number of bounding boxes per class.

        Args:
            output_path: Path for output figure
            show: Show the resulting diagram
        """
        # Clear and reuse the current figure. Calling plt.subplots() on every call
        # would open a new figure each time, and pyplot keeps those open.
        plt.clf()
        ax = plt.gca()

        class_ids = list(self.classes.keys())
        class_names = list(self.classes.values())
        bar_heights = [len(self.bboxes_by_class(name)) for name in class_names]

        ax.set_title("Class Distribution")
        ax.set_xlabel("Class ID")
        ax.set_ylabel("Num Targets")
        ax.set_xticks(class_ids)
        # Place each name above its own bar. The bar sits at its class id, not at
        # its list position, so labels stay aligned for non-contiguous ids.
        for class_id, label, height in zip(class_ids, class_names, bar_heights):
            ax.text(class_id, height + 10, label, ha="center", rotation=90)

        ax.bar(class_ids, bar_heights)
        self._finish_plot(output_path, show)

    def scatter_plot(
        self,
        class_names: Optional[Union[str, List[str]]] = None,
        output_path: Optional[str] = None,
        show: bool = False,
    ) -> None:
        """Draw a scatter plot of bounding box center points.

        Prints a message and draws nothing if no bounding box matches.

        Args:
            class_names: Filter bounding boxes by class names.
                If None, use all bounding boxes.
            output_path: Path for output figure
            show: Show the resulting diagram
        """
        plt.clf()
        names = self._as_name_list(class_names)

        plt.title(self._plot_title("Scatter plot of Bounding Box center points", names))
        plt.xlabel("X")
        plt.ylabel("Y", rotation=0)
        # Presumably inverted so the plot matches image orientation (y grows down).
        plt.gca().invert_yaxis()

        center_points = [box.center_point for box in self.bboxes_by_class(names)]
        if not center_points:
            print(f"No bounding boxes for class: {names}")
            return
        plt.scatter(*zip(*center_points))
        self._finish_plot(output_path, show)

    def heatmap_plot(
        self,
        class_names: Optional[Union[str, List[str]]] = None,
        output_path: Optional[str] = None,
        show: bool = False,
    ) -> None:
        """Draw a heatmap counting, per pixel, how many bounding boxes cover it.

        Args:
            class_names: Filter bounding boxes by class names.
                If None, use all bounding boxes.
            output_path: Path for output figure
            show: Show the resulting diagram
        """
        plt.clf()
        names = self._as_name_list(class_names)

        # img_size[0] is the extent along x and img_size[1] along y. Build the
        # heatmap directly in [row=y, column=x] layout, as plt.imshow expects.
        width, height = self.img_size
        heatmap = np.zeros((height, width))
        for box in self.bboxes_by_class(names):
            # Slice bounds must be non-negative ints. Float coordinates are not
            # valid slice indices, and negative ones would wrap to the array end.
            x1, x2 = max(int(box.x1), 0), max(int(box.x2), 0)
            y1, y2 = max(int(box.y1), 0), max(int(box.y2), 0)
            heatmap[y1:y2, x1:x2] += 1

        plt.title(self._plot_title("Heatmap of Bounding Boxes", names))
        plt.xlabel("X")
        plt.ylabel("Y", rotation=0)
        plt.imshow(heatmap, cmap="hot", interpolation="nearest")
        self._finish_plot(output_path, show)

    # ------------------------------------------------------------------- report

    def basic_info(self) -> str:
        """Get some basic information about the dataset.

        Returns:
            Information about: Path, num classes, num bounding boxes, num images,
            image size
        """
        out = ""
        out += f"Path to dataset: {self._base_path}\n"
        out += f"Number of Classes: {len(self.classes)}\n"
        out += f"Number of Bounding Boxes: {len(self.bboxes_by_class())}\n"
        out += f"Number of Images: {len(self)}\n"
        out += f"Image Size: {self.img_size}\n"
        return out

    def report(self, output_folder: Optional[str] = "analysis") -> str:
        """Collect dataset statistics in a report, print it, and optionally save it.

        If ``output_folder`` is set, it receives ``class_distribution.png``,
        ``scatter/<class>.png``, ``heatmaps/<class>.png`` (``total.png`` for all
        classes) and ``report.txt``. Otherwise only the report is printed.

        Args:
            output_folder: Path to store the report

        Returns:
            The report text.
        """
        if output_folder:
            os.makedirs(os.path.join(output_folder, "scatter"), exist_ok=True)
            os.makedirs(os.path.join(output_folder, "heatmaps"), exist_ok=True)
            self.draw_class_distribution(
                os.path.join(output_folder, "class_distribution.png"), show=False
            )

        # Collect statistics. The trailing None entry stands for "all classes".
        total_num_boxes = len(self.bboxes_by_class())
        class_ids = list(self.classes.keys()) + [None]
        class_names = list(self.classes.values()) + [None]

        num_labels = []
        for name in class_names:
            count = len(self.bboxes_by_class(name))
            # Guard against an empty dataset (no boxes at all).
            share = 100 * count / total_num_boxes if total_num_boxes else 0.0
            num_labels.append(f"{count}, {share:.2f} %")

        widths = [self._format_stats(self.statistical_width(n)) for n in class_names]
        heights = [self._format_stats(self.statistical_height(n)) for n in class_names]
        aspect_ratios = [
            self._format_stats(self.statistical_aspect_ratio(n)) for n in class_names
        ]

        results = [class_ids, class_names, num_labels, widths, heights, aspect_ratios]
        headers = ["ID", "Class Name", "Targets", "Width", "Height", "Aspect Ratio"]

        # Create scatter plots and heatmaps per class, plus one for all classes.
        if output_folder:
            for class_name in class_names:
                file_name = "total.png" if class_name is None else f"{class_name}.png"
                self.scatter_plot(
                    class_name,
                    output_path=os.path.join(output_folder, "scatter", file_name),
                )
                self.heatmap_plot(
                    class_name,
                    output_path=os.path.join(output_folder, "heatmaps", file_name),
                )

        # Label the all-classes row. class_names is the same list object held in
        # ``results``, so the table picks this up.
        class_names[-1] = "TOTAL"

        table = tabulate(zip(*results), headers=headers, tablefmt="grid")
        report = "BBox Dataset Analysis\n"
        report += "---------------------\n\n"
        report += "Basic Information\n\n"
        report += f"{self.basic_info()}\n"
        report += "Detailed Information\n\n"
        report += table

        print(report)

        if output_folder:
            with open(
                os.path.join(output_folder, "report.txt"), "w", encoding="utf-8"
            ) as f:
                f.write(report)
        return report
if __name__ == "__main__":
    # Command-line entry point: analyse the dataset described by a label file and
    # write the report (plots + report.txt) into the output folder.
    parser = ArgumentParser(description="Analyse a bounding box dataset.")
    parser.add_argument(
        "--label-file",
        type=str,
        # Required: without it None would be passed to from_yaml and fail obscurely.
        required=True,
        help="Path to dataset label file.",
    )
    parser.add_argument(
        "--output-folder",
        type=str,
        default="analysis",
        help="Output folder for analysis.",
    )
    args = parser.parse_args()
    AnalyseBBoxDataset.from_yaml(args.label_file).report(args.output_folder)