import itertools
import os
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Tuple, Union
import matplotlib.pyplot as plt
import numpy as np
from kitcar_ml.utils.bounding_box import BoundingBox
from kitcar_ml.utils.evaluation.evaluator import Evaluator
@dataclass
class InterpolationResult:
    """Data class containing the interpolation results."""

    recall: List[float]
    """List with the recall values."""
    precision: List[float]
    """List with the precision values."""
    ap: float
    """Average precision."""
    recall_interpolation: List[float]
    """Interpolated recall values."""
    precision_interpolation: List[float]
    """Interpolated precision values."""
    total_positives: int
    """Total number of ground truth positives."""
    true_positives: int
    """Number of true positive detections."""
    false_positives: int
    """Number of false positive detections."""
def head(iterable):
    """Return the first element of an indexable sequence.

    Used e.g. as a ``key=`` function when sorting lists of tuples.

    Args:
        iterable: Any object supporting ``[0]`` indexing (list, tuple, str, ...).

    Raises:
        IndexError: If the sequence is empty.
    """
    return iterable[0]
class InterpolationEvaluator(Evaluator):
    """Evaluate object detections with interpolated precision/recall metrics.

    For every configured IoU threshold the average precision (AP) is computed,
    using either the every-point interpolation or the Pascal VOC 11-point
    interpolation of the precision/recall curve.
    """

    def __init__(
        self,
        iou_thresholds: Tuple[float, ...] = (0.5, 0.75, 0.95),
        use_every_point_interpolation: bool = True,
    ):
        """Initialize the evaluator.

        Args:
            iou_thresholds: IoU thresholds at which the metrics are evaluated.
            use_every_point_interpolation: If True, use every-point interpolation;
                otherwise use the 11-point interpolation.
        """
        super().__init__()
        self.iou_thresholds = iou_thresholds
        # Per-IoU result dict and mAP; filled by __call__().
        self.results = None
        self.m_ap = None
        # Classes present in the groundtruth; populated by calculate_results().
        # Initialized here so plot_precision_recall_curves() never hits an
        # AttributeError if called before an evaluation ran.
        self.gt_classes = set()
        self.use_every_point_interpolation = use_every_point_interpolation

    def __call__(
        self,
        groundtruth: List[List[BoundingBox]],
        detections: List[List[BoundingBox]],
        classes: Union[List[str], str] = "all",
    ):
        """Evaluate the detections with the given groundtruth and plot the pascal voc
        metrics per image.

        Args:
            groundtruth: Is a list of BoundingBox that contain the groundtruth data.
            detections: Is a list of BoundingBox that contain the detections from the model.
            classes: Classes that should be considered.
        """
        voc_res = {}
        m_ap_vector = {}
        for iou in self.iou_thresholds:
            voc_res[iou], m_ap_vector[iou] = self.calculate_results(
                groundtruth, detections, classes, iou
            )
        self.results = voc_res
        self.m_ap = m_ap_vector

    @staticmethod
    def _class_label_string(pair, name) -> str:
        """Format a (precision, label) pair as a human-readable sentence."""
        return f"{pair[1]} is {name} class, with {round(100*pair[0])}% average precision."

    def __str__(self) -> str:
        spacer = "\n----------------\n"
        # Only the IoU keys are needed here; the per-class dicts are not rendered.
        return spacer.join(
            f"IoU {iou}: \nmAP: {self.m_ap[iou]}" + spacer for iou in self.results
        )

    @classmethod
    def calculate_ap_every_point(
        cls, recall_vector: np.ndarray, precision_vector: np.ndarray
    ) -> Tuple[float, List[float], List[float]]:
        """Interpolate ap for every point.

        Args:
            recall_vector: numpy array of recalls
            precision_vector: numpy array of precision

        Returns:
            ap: The average precision.
            recall_interpolation: The interpolated recall
            precision_interpolation: The interpolated precision
        """
        # Add buffer values so the curve spans recall 0..1.
        recall_interpolation = [0, *recall_vector, 1]
        precision_interpolation = [0, *precision_vector, 0]
        # Make the precision monotonically non-increasing: each point takes the
        # running maximum of the precision values to its right.
        reverse_accumulated_precision = itertools.accumulate(
            precision_interpolation[::-1], func=max
        )
        precision_interpolation = list(reverse_accumulated_precision)[::-1]
        # The area under the curve is the ap, because it interpolates the ratio of the
        # true positives and false positives.
        # Calculate the lengths of the areas under the curve
        recall_differences = np.diff(recall_interpolation)
        # Multiply it with the height and sum them.
        average_precision = sum(recall_differences * precision_interpolation[1:])
        # Last element was just for calculations.
        return (
            average_precision,
            recall_interpolation[:-1],
            precision_interpolation[:-1],
        )

    @staticmethod
    def calculate_interpolation_points(
        recall_interpolation: List[float], precision_interpolation: List[float]
    ) -> List[Tuple[int, int]]:
        """Calculate the interpolated points for the recall and the precision. The maximal
        precision is used for equal recall value.

        Args:
            recall_interpolation: The interpolation of the recall values
            precision_interpolation: The interpolation of the precision values
        """
        # Pad both lists so every recall value is paired with its own precision
        # and the precision of its predecessor.
        recall_values = [1, *recall_interpolation, 0]
        precision_values = [0, 0, *precision_interpolation, 0]
        interpolation_dict = defaultdict(int)
        # Find the maximum precision value for each previous recall and current recall value
        for r, precision, previous_prec in zip(
            recall_values, precision_values[1:], precision_values[:-1]
        ):
            interpolation_dict[r] = max(interpolation_dict[r], precision, previous_prec)
        # Sort by recall in descending order (key = first tuple element).
        return sorted(interpolation_dict.items(), key=head, reverse=True)

    @classmethod
    def calculate_ap_11_point_interp(
        cls, recall_vector: np.ndarray, precision_vector: np.ndarray
    ) -> Tuple[float, List[float], List[float]]:
        """Interpolate recall and precision at eleven points.

        Args:
            recall_vector: numpy array of recall values
            precision_vector: numpy array of precision values

        Returns:
            ap: The average precision.
            recall_interpolation: The interpolated recall
            precision_interpolation: The interpolated precision
        """
        # Eleven equally spaced recall levels, from 1.0 down to 0.0.
        recall_interpolation = list(np.linspace(1, 0, 11))
        # Calculate the index where the recall is higher than the interpolation.
        # -1 marks recall levels that are never reached.
        start_indices = [
            next((i for i, recall in enumerate(recall_vector) if recall >= r), -1)
            for r in recall_interpolation
        ]
        # Calculate maximum precision starting with the matching recall index.
        precision_interpolation = [
            0 if i == -1 else max(precision_vector[i:]) for i in start_indices
        ]
        interpolated_points = cls.calculate_interpolation_points(
            recall_interpolation, precision_interpolation
        )
        # Unpack list of tuples.
        recall_interpolation, precision_interpolation = zip(*interpolated_points)
        ap = sum(precision_interpolation) / len(recall_interpolation)
        return ap, recall_interpolation, precision_interpolation

    @staticmethod
    def calculate_sorted_prefix_sum(
        detections: List[List[BoundingBox]],
        true_positives: np.ndarray,
        false_positives: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Sorts the true and false positive arrays and calculates the prefix sum.

        Args:
            detections: List of all detections from this class.
            true_positives: Array that defines the true positive.
            false_positives: Array that defines the false positive.
        """
        # Sort both arrays by detection confidence, descending.
        permutation = np.argsort([-1 * bb.confidence for dets in detections for bb in dets])
        return np.cumsum(true_positives[permutation]), np.cumsum(
            false_positives[permutation]
        )

    @classmethod
    def calculate_metrics(
        cls,
        groundtruth: List[List[BoundingBox]],
        detections: List[List[BoundingBox]],
        true_positives: np.ndarray,
        calculate_interpolation,
    ) -> InterpolationResult:
        """Calculate the metrics for the interpolation.

        Args:
            groundtruth: List that contain the groundtruth bounding boxes.
            detections: List that contain the detections bounding boxes.
            true_positives: Array of the true positive value for each detection.
            calculate_interpolation: The function that calculates the interpolation.
        """
        positives = sum(len(gt) for gt in groundtruth)
        # Every detection that is not a true positive is a false positive.
        false_positives = np.invert(true_positives)
        tp_prefix_sum, fp_prefix_sum = cls.calculate_sorted_prefix_sum(
            detections, true_positives, false_positives
        )
        # Element i holds precision/recall after the i+1 most confident detections.
        precision = np.divide(tp_prefix_sum, tp_prefix_sum + fp_prefix_sum)
        recall = tp_prefix_sum / positives
        ap, recall_interpolation, precision_interpolation = calculate_interpolation(
            recall, precision
        )
        return InterpolationResult(
            recall,
            precision,
            ap,
            recall_interpolation,
            precision_interpolation,
            positives,
            0 if len(tp_prefix_sum) == 0 else tp_prefix_sum[-1],
            0 if len(fp_prefix_sum) == 0 else fp_prefix_sum[-1],
        )

    @classmethod
    def calculate_class_results(
        cls,
        groundtruth: List[List[List[BoundingBox]]],
        detections: List[List[List[BoundingBox]]],
        calculate_interpolation,
        iou_threshold: float,
    ) -> InterpolationResult:
        """Calculate the interpolation, true positives and false positives for groundtruth
        and detection.

        Args:
            groundtruth: List that contain the groundtruth bounding boxes.
            detections: List that contain the detections bounding boxes.
            iou_threshold: The threshold that bounds the acceptance of a detection.
            calculate_interpolation: The function that calculates the interpolation.
        """
        # One TP array per class, then merged into one flat array.
        true_positives = [
            cls.calculate_all_tp(gts, dets, iou_threshold)
            for gts, dets in zip(groundtruth, detections)
        ]
        true_positives = np.concatenate(true_positives)
        ret = cls.calculate_metrics(
            # Flatten the per-class lists of images into one list each.
            sum(groundtruth, []),
            sum(detections, []),
            true_positives,
            calculate_interpolation,
        )
        return ret

    def calculate_results(
        self,
        groundtruth: List[List[BoundingBox]],
        detections: List[List[BoundingBox]],
        classes: List[str],
        iou_threshold: float = 0.5,
    ):
        """Calculate the metrics of all classes.

        Args:
            groundtruth: List of BoundingBoxes representing groundtruth bounding boxes;
            detections: List of BoundingBoxes representing detections bounding boxes;
            iou_threshold: IOU threshold indicating which detections
                will be considered TP or FP

        Returns:
            A result dictionary for every class.
        """
        ret = {}
        calculate_interpolation_func = (
            self.calculate_ap_every_point
            if self.use_every_point_interpolation
            else self.calculate_ap_11_point_interp
        )
        classes_bbs, self.gt_classes = self.split_bbs_per_class(groundtruth, detections)
        if len(self.gt_classes) == 0:
            # No groundtruth at all: perfect score (1) only if there are also
            # no detections, otherwise everything is a false positive (0).
            return {}, 0 if len(sum(detections, [])) > 0 else 1
        # Create result for all classes
        all_bb_dicts = defaultdict(list)
        for class_, bb_dict in classes_bbs.items():
            # Precedence: (requested AND in groundtruth) OR evaluating everything.
            if class_ in classes and class_ in self.gt_classes or classes == "all":
                for key, value in bb_dict.items():
                    all_bb_dicts[key].append(value)
        if len(all_bb_dicts.keys()) == 0:
            print("No valid classes given.")
            return None, -1
        name = ", ".join(classes)
        if len(classes) == len(self.gt_classes) or classes == "all":
            name = "all"
        ret[name] = self.calculate_class_results(
            all_bb_dicts["gt"],
            all_bb_dicts["det"],
            calculate_interpolation_func,
            iou_threshold,
        )
        return ret, ret[name].ap

    def plot_precision_recall_curves(
        self,
        classes=None,
        show_interpolated_precision: bool = True,
        save_path: str = None,
        save_prefix: str = "plot",
        show_graphic: bool = True,
    ):
        """Plot the precision and recall curve.

        Args:
            classes: The classes that should be plotted, "all" can be a class.
            show_interpolated_precision: True if the interpolation should be shown.
            save_path: Save path of the plots, plots are not saved if no path is given.
            save_prefix: The prefix of all saved files.
            show_graphic: True if the plots should be shown.
        """
        if not classes:
            classes = self.gt_classes | {"all"}
        if self.use_every_point_interpolation:
            label = "Interpolated precision"
            plot_type = "--r"
        else:
            label = "11-point interpolated precision"
            plot_type = "or"
        for iou in self.iou_thresholds:
            for class_label, result in self.results[iou].items():
                if class_label not in classes:
                    continue
                # BUGFIX: this guard previously ran after result.ap was already
                # accessed, so it could never fire. Check before any dereference.
                # NOTE(review): OSError is an odd exception type here, but it is
                # kept for backward compatibility with existing callers.
                if result is None:
                    raise OSError(f"Error: Class {class_label} could not be found.")
                title = f"Precision/Recall curve,\nClass: {class_label} IoU: {iou}"
                ndets = result.false_positives + result.true_positives
                txt = "".join(
                    (
                        f"AP: {round(result.ap * 100, 1)}% ",
                        f"detections: {ndets}, groundtruths:",
                        f"{result.total_positives}",
                    )
                )
                plt.close()
                fig = plt.figure()
                ax = fig.add_axes((0.1, 0.2, 0.8, 0.7))
                if show_interpolated_precision:
                    ax.plot(
                        result.recall_interpolation,
                        result.precision_interpolation,
                        plot_type,
                        label=label,
                    )
                ax.plot(result.recall, result.precision, label="Precision")
                plt.xlabel("recall")
                plt.ylabel("precision")
                fig.text(0.5, 0.05, txt, ha="center")
                plt.title(title)
                plt.legend(shadow=True)
                plt.grid()
                if save_path is not None:
                    os.makedirs(save_path, exist_ok=True)
                    path = os.path.join(
                        save_path, f"{save_prefix}_{iou}_{class_label}.png"
                    )
                    plt.xlim(-0.05, 1)
                    plt.ylim(-0.05, 1.05)
                    plt.savefig(path)
                    print(f"saved to {path}")
                if show_graphic:
                    plt.show()
                    plt.pause(0.05)