import os
from argparse import ArgumentParser
from functools import cached_property
from typing import List, Optional, Tuple, Union
import matplotlib.pyplot as plt
import numpy as np
from tabulate import tabulate
from kitcar_ml.utils.bounding_box import BoundingBox
from kitcar_ml.utils.data.bbox_dataset import BBoxDataset
class AnalyseBBoxDataset(BBoxDataset):
    """Bounding box dataset with helpers to analyse and visualise its labels.

    Adds per-class statistics (width, height, aspect ratio), plots (class
    distribution, center-point scatter, bounding box heatmap) and a text
    report on top of ``BBoxDataset``.
    """

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _as_class_list(class_names):
        """Wrap a single (truthy) class name in a list; pass lists and None through."""
        if class_names and not isinstance(class_names, list):
            return [class_names]
        return class_names

    @staticmethod
    def _mean_std(values: List[float]) -> Tuple[float, float]:
        """Return ``(mean, std)`` of ``values``; ``(nan, nan)`` for an empty list.

        The empty case matches what numpy returns, but without the
        RuntimeWarnings numpy emits for empty input.
        """
        if not values:
            return float("nan"), float("nan")
        array = np.asarray(values, dtype=float)
        return array.mean().item(), array.std().item()

    @staticmethod
    def _plot_title(base: str, class_names: Optional[List[str]]) -> str:
        """Build a plot title, appending the filtered class name(s) if any."""
        if not class_names:
            return base
        shown = class_names[0] if len(class_names) == 1 else class_names
        return f"{base} for class: {shown}"

    @staticmethod
    def _show_maximized() -> None:
        """Show the current figure, maximizing its window when the backend allows it."""
        window = getattr(plt.get_current_fig_manager(), "window", None)
        # ``showMaximized`` is a Qt window method; on other backends fall back
        # to showing the figure at its default size instead of crashing.
        if window is not None and hasattr(window, "showMaximized"):
            window.showMaximized()
        plt.show()

    # --------------------------------------------------------------- statistics

    def statistical_height(
        self, class_names: Optional[Union[str, List[str]]] = None
    ) -> Tuple[float, float]:
        """Get the average and standard deviation height of all bounding boxes.

        Args:
            class_names: Filter bounding boxes by class names.
                If None, use all bounding boxes.

        Returns:
            ``(mean, std)``; both are ``nan`` when no bounding box matches.
        """
        heights = [box.height for box in self.bboxes_by_class(class_names)]
        return self._mean_std(heights)

    def statistical_width(
        self, class_names: Optional[Union[str, List[str]]] = None
    ) -> Tuple[float, float]:
        """Get the average and standard deviation width of all bounding boxes.

        Args:
            class_names: Filter bounding boxes by class names.
                If None, use all bounding boxes.

        Returns:
            ``(mean, std)``; both are ``nan`` when no bounding box matches.
        """
        widths = [box.width for box in self.bboxes_by_class(class_names)]
        return self._mean_std(widths)

    def statistical_aspect_ratio(
        self, class_names: Optional[Union[str, List[str]]] = None
    ) -> Tuple[float, float]:
        """Get the average and standard deviation aspect ratio (width / height).

        Boxes with zero height are skipped because their ratio is undefined.

        Args:
            class_names: Filter bounding boxes by class names.
                If None, use all bounding boxes.

        Returns:
            ``(mean, std)``; both are ``nan`` when no usable bounding box matches.
        """
        ratios = [
            box.width / box.height
            for box in self.bboxes_by_class(class_names)
            if box.height
        ]
        return self._mean_std(ratios)

    @cached_property
    def img_size(self) -> Tuple[int, int]:
        """Get image size (``.size`` of the first sample's image).

        Assumption: all images have the same size. ``heatmap_plot`` indexes
        this shape as ``[x, y]``, i.e. presumably ``(width, height)``.
        """
        return self.__getitem__(0)[0].size

    def bboxes_by_class(
        self, class_names: Optional[Union[str, List[str]]] = None
    ) -> List[BoundingBox]:
        """List bounding boxes.

        Args:
            class_names: Filter bounding boxes by class names (a single name
                or a list of names). If None, use all bounding boxes.

        Returns:
            One ``BoundingBox(x1, y1, x2, y2, class_name)`` per matching label.
        """
        class_names = self._as_class_list(class_names)

        # Column positions of the needed attributes inside each label row.
        x1_index = self.attributes.index("x1")
        x2_index = self.attributes.index("x2")
        y1_index = self.attributes.index("y1")
        y2_index = self.attributes.index("y2")
        class_id_index = self.attributes.index("class_id")

        boxes = []
        for labels in self.labels.values():
            for label in labels:
                name = self.classes[label[class_id_index]]
                if class_names is not None and name not in class_names:
                    continue
                boxes.append(
                    BoundingBox(
                        label[x1_index],
                        label[y1_index],
                        label[x2_index],
                        label[y2_index],
                        name,
                    )
                )
        return boxes

    # -------------------------------------------------------------------- plots

    def draw_class_distribution(
        self, output_path: Optional[str] = None, show: bool = False
    ):
        """Draw class distribution.

        Args:
            output_path: Path for output figure
            show: Show the resulting diagram
        """
        # Clear and reuse the current figure instead of opening a new one on
        # every call (plt.subplots would leak a figure per call).
        plt.clf()
        ax = plt.gca()

        class_ids = list(self.classes.keys())
        class_names = list(self.classes.values())
        bar_heights = [len(self.bboxes_by_class(name)) for name in class_names]

        ax.set_title("Class Distribution")
        ax.set_xlabel("Class ID")
        ax.set_ylabel("Num Targets")
        ax.set_xticks(class_ids)
        ax.bar(class_ids, bar_heights)
        # Place each name above its own bar: bars sit at their class ID, not at
        # the enumeration index, so IDs need not be 0..n-1.
        for class_id, label, height in zip(class_ids, class_names, bar_heights):
            ax.text(class_id, height + 10, label, ha="center", rotation=90)

        if output_path:
            plt.savefig(output_path)
        if show:
            self._show_maximized()

    def scatter_plot(
        self,
        class_names: Optional[Union[str, List[str]]] = None,
        output_path: Optional[str] = None,
        show: bool = False,
    ):
        """Draw scatter plot of bounding box center points.

        Args:
            class_names: Filter bounding boxes by class names.
                If None, use all bounding boxes.
            output_path: Path for output figure
            show: Show the resulting diagram
        """
        plt.clf()
        class_names = self._as_class_list(class_names)

        plt.title(
            self._plot_title("Scatter plot of Bounding Box center points", class_names)
        )
        plt.xlabel("X")
        plt.ylabel("Y", rotation=0)
        # Image coordinates grow downwards.
        plt.gca().invert_yaxis()

        center_points = [box.center_point for box in self.bboxes_by_class(class_names)]
        if not center_points:
            print(f"No bounding boxes for class: {class_names}")
            return
        plt.scatter(*zip(*center_points))

        if output_path:
            plt.savefig(output_path)
        if show:
            self._show_maximized()

    def heatmap_plot(
        self,
        class_names: Optional[Union[str, List[str]]] = None,
        output_path: Optional[str] = None,
        show: bool = False,
    ):
        """Draw heatmap of bounding boxes.

        Each pixel counts how many matching bounding boxes cover it.

        Args:
            class_names: Filter bounding boxes by class names.
                If None, use all bounding boxes.
            output_path: Path for output figure
            show: Show the resulting diagram
        """
        plt.clf()
        class_names = self._as_class_list(class_names)

        # Accumulator indexed [x, y] with the shape of ``img_size``.
        heatmap = np.zeros(self.img_size)
        for box in self.bboxes_by_class(class_names):
            # Slices need ints; clamp the start at 0 so boxes reaching past the
            # left/top border are not wrapped around by negative indexing.
            # NOTE(review): assumes pixel (not normalized) coordinates.
            x1, x2 = max(int(box.x1), 0), int(box.x2)
            y1, y2 = max(int(box.y1), 0), int(box.y2)
            heatmap[x1:x2, y1:y2] += 1

        plt.title(self._plot_title("Heatmap of Bounding Boxes", class_names))
        plt.xlabel("X")
        plt.ylabel("Y", rotation=0)
        # plt.imshow needs shape [Height, Width]; the accumulator is [x, y].
        plt.imshow(heatmap.T, cmap="hot", interpolation="nearest")

        # Save before showing: after plt.show() the figure may already be
        # closed, which would write an empty image.
        if output_path:
            plt.savefig(output_path)
        if show:
            self._show_maximized()

    # ------------------------------------------------------------------ reports

    def basic_info(self) -> str:
        """Get some basic information about the dataset.

        Returns:
            Information about: Path, num classes, num bounding boxes, num images, image size
        """
        lines = [
            f"Path to dataset: {self._base_path}",
            f"Number of Classes: {len(self.classes)}",
            f"Number of Bounding Boxes: {len(self.bboxes_by_class())}",
            f"Number of Images: {len(self)}",
            f"Image Size: {self.img_size}",
        ]
        return "".join(f"{line}\n" for line in lines)

    def report(self, output_folder: Optional[str] = "analysis"):
        """Collect information in a report, print it and optionally save it.

        When ``output_folder`` is set, this also writes the class distribution,
        per-class scatter plots and heatmaps, and ``report.txt`` there.

        Args:
            output_folder: Path to store the report; falsy to only print it.
        """
        if output_folder:
            os.makedirs(os.path.join(output_folder, "scatter"), exist_ok=True)
            os.makedirs(os.path.join(output_folder, "heatmaps"), exist_ok=True)
            self.draw_class_distribution(
                os.path.join(output_folder, "class_distribution.png"), show=False
            )

        # One row per class plus a final ``None`` row that aggregates all boxes.
        total_num_boxes = len(self.bboxes_by_class())
        class_ids = list(self.classes.keys()) + [None]
        class_names = list(self.classes.values()) + [None]

        def fmt(stats: Tuple[float, float]) -> str:
            return f"{stats[0]:.2f} +-{stats[1]:.2f}"

        num_labels = []
        for name in class_names:
            count = len(self.bboxes_by_class(name))
            # Guard against an empty dataset (would divide by zero).
            share = 100 * count / total_num_boxes if total_num_boxes else 0.0
            num_labels.append(f"{count}, {share:.2f} %")
        widths = [fmt(self.statistical_width(name)) for name in class_names]
        heights = [fmt(self.statistical_height(name)) for name in class_names]
        aspect_ratios = [
            fmt(self.statistical_aspect_ratio(name)) for name in class_names
        ]

        results = [class_ids, class_names, num_labels, widths, heights, aspect_ratios]
        headers = ["ID", "Class Name", "Targets", "Width", "Height", "Aspect Ratio"]

        if output_folder:
            for class_name in class_names:
                file_name = "total.png" if class_name is None else f"{class_name}.png"
                self.scatter_plot(
                    class_name,
                    output_path=os.path.join(output_folder, "scatter", file_name),
                )
                self.heatmap_plot(
                    class_name,
                    output_path=os.path.join(output_folder, "heatmaps", file_name),
                )

        # Rename None class to TOTAL (``results`` holds the same list object).
        class_names[-1] = "TOTAL"

        table = tabulate(zip(*results), headers=headers, tablefmt="grid")
        report = "BBox Dataset Analysis\n"
        report += "---------------------\n\n"
        report += "Basic Information\n\n"
        report += f"{self.basic_info()}\n"
        report += "Detailed Information\n\n"
        report += table

        print(report)

        if output_folder:
            with open(
                os.path.join(output_folder, "report.txt"), "w", encoding="utf-8"
            ) as f:
                f.write(report)
if __name__ == "__main__":
    # Command-line entry point: load a dataset from its label file and
    # print/save the analysis report.
    parser = ArgumentParser(description="Analyse a bounding box dataset.")
    # Required: without it ``from_yaml(None)`` would fail with an obscure error.
    parser.add_argument(
        "--label-file", type=str, required=True, help="Path to dataset label file."
    )
    parser.add_argument(
        "--output-folder", type=str, default="analysis", help="Output folder for analysis."
    )
    args = parser.parse_args()
    AnalyseBBoxDataset.from_yaml(args.label_file).report(args.output_folder)