# Source code for kitcar_ml.utils.evaluation.interpolation_evaluator

import itertools
import os
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np

from kitcar_ml.utils.bounding_box import BoundingBox
from kitcar_ml.utils.evaluation.evaluator import Evaluator


@dataclass
class InterpolationResult:
    """Data class bundling the results of one precision/recall interpolation."""

    recall: List[float]
    """List with the recall values."""

    precision: List[float]
    """List with the precision values."""

    ap: float
    """Average precision."""

    recall_interpolation: List[float]
    """Interpolated recall values."""

    precision_interpolation: List[float]
    """Interpolated precision values."""

    total_positives: int
    """Total number of ground truth positives."""

    true_positives: int
    """Number of true positive detections."""

    false_positives: int
    """Number of false positive detections."""
class InterpolationEvaluator(Evaluator):
    """Evaluator that computes interpolated precision/recall metrics per IoU threshold."""

    def __init__(
        self,
        iou_thresholds: Tuple[float, ...] = (0.5, 0.75, 0.95),
        use_every_point_interpolation: bool = True,
    ):
        """Set up the evaluator.

        Args:
            iou_thresholds: IoU thresholds to evaluate against.
            use_every_point_interpolation: If True, use every-point interpolation;
                otherwise the 11-point interpolation is used.
        """
        super().__init__()
        self.iou_thresholds = iou_thresholds
        # Filled by __call__ / calculate_results.
        self.results = None
        self.m_ap = None
        self.use_every_point_interpolation = use_every_point_interpolation

    def __call__(
        self,
        groundtruth: List[List[BoundingBox]],
        detections: List[List[BoundingBox]],
        classes: Union[List[str], str] = "all",
    ):
        """Evaluate the detections against the groundtruth for every IoU threshold.

        Args:
            groundtruth: Is a list of BoundingBox that contain the groundtruth data.
            detections: Is a list of BoundingBox that contain the detections from
                the model.
            classes: Classes that should be considered.
        """
        # calculate_results yields (per-class results, mAP) per threshold.
        per_threshold = {
            threshold: self.calculate_results(groundtruth, detections, classes, threshold)
            for threshold in self.iou_thresholds
        }
        self.results = {iou: pair[0] for iou, pair in per_threshold.items()}
        self.m_ap = {iou: pair[1] for iou, pair in per_threshold.items()}
[docs] @staticmethod def _class_label_string(pair, name) -> str: return f"{pair[1]} is {name} class, with {round(100*pair[0])}% average precision."
def __str__(self) -> str: spacer = "\n----------------\n" return spacer.join( f"IoU {iou}: \nmAP: {self.m_ap[iou]}" + spacer for iou, class_dict in self.results.items() )
[docs] @classmethod def calculate_ap_every_point( cls, recall_vector: np.ndarray, precision_vector: np.ndarray ) -> Tuple[float, List[float], List[float]]: """Interpolate ap for every point. Args: recall_vector: numpy array of recalls precision_vector: numpy array of precision Returns: ap: The average precision. recall_interpolation: The interpolated recall precision_interpolation: The interpolated precision """ # Add buffer values around recall and precision recall_interpolation = [0, *recall_vector, 1] precision_interpolation = [0, *precision_vector, 0] # Interpolate the precision. reverse_accumulated_precision = itertools.accumulate( precision_interpolation[::-1], func=max ) precision_interpolation = list(reverse_accumulated_precision)[::-1] # The area under the curve is the ap, because it interpolates the ratio of the # true positives and false positives. # Calculate the lengths of the areas under the curve recall_differences = np.diff(recall_interpolation) # Multiply it with the height and sum them. average_precision = sum(recall_differences * precision_interpolation[1:]) # Last element was just for calculations. return ( average_precision, recall_interpolation[:-1], precision_interpolation[:-1], )
[docs] @staticmethod def calculate_interpolation_points( recall_interpolation: List[float], precision_interpolation: List[float] ) -> List[Tuple[int, int]]: """Calculate the interpolated points for the recall and the precision. The maximal precision is used for equal recall value. Args: recall_interpolation: The interpolation of the recall values precision_interpolation: The interpolation of the precision values """ recall_values = [1, *recall_interpolation, 0] precision_values = [0, 0, *precision_interpolation, 0] interpolation_dict = defaultdict(int) # Find the maximum precision value for each previous recall and current recall value for r, precision, previous_prec in zip( recall_values, precision_values[1:], precision_values[:-1] ): interpolation_dict[r] = max(interpolation_dict[r], precision, previous_prec) return sorted(interpolation_dict.items(), key=head, reverse=True)
[docs] @classmethod def calculate_ap_11_point_interp( cls, recall_vector: np.ndarray, precision_vector: np.ndarray ) -> Tuple[float, List[float], List[float]]: """Interpolate recall and precision at eleven points. Args: recall_vector: numpy array of recall values precision_vector: numpy array of precision values Returns: ap: The average precision. recall_interpolation: The interpolated recall precision_interpolation: The interpolated precision """ recall_interpolation = list(np.linspace(1, 0, 11)) # Calculate the index where the recall is higher than the interpolation. start_indices = [ next((i for i, recall in enumerate(recall_vector) if recall >= r), -1) for r in recall_interpolation ] # Calculate maximum precision starting with the matching recall index. precision_interpolation = [ 0 if i == -1 else max(precision_vector[i:]) for i in start_indices ] interpolated_points = cls.calculate_interpolation_points( recall_interpolation, precision_interpolation ) # Unpack list of tuples. recall_interpolation, precision_interpolation = zip(*interpolated_points) ap = sum(precision_interpolation) / len(recall_interpolation) return ap, recall_interpolation, precision_interpolation
[docs] @staticmethod def calculate_sorted_prefix_sum( detections: List[List[BoundingBox]], true_positives: np.ndarray, false_positives: np.ndarray, ) -> Tuple[np.ndarray, np.ndarray]: """Sorts the true and false positive arrays and calculates the prefix sum. Args: detections: List of all detections from this class. true_positives: Array that defines the true positive. false_positives: Array that defines the false positive. """ permutation = np.argsort([-1 * bb.confidence for dets in detections for bb in dets]) return np.cumsum(true_positives[permutation]), np.cumsum( false_positives[permutation] )
[docs] @classmethod def calculate_metrics( cls, groundtruth: List[List[BoundingBox]], detections: List[List[BoundingBox]], true_positives: np.ndarray, calculate_interpolation, ) -> InterpolationResult: """Calculate the metrics for the interpolation. Args: groundtruth: List that contain the groundtruth bounding boxes. detections: List that contain the detections bounding boxes. true_positives: Array of the true positive value for each detection. calculate_interpolation: The function that calculates the interpolation. """ positives = sum(len(gt) for gt in groundtruth) false_positives = np.invert(true_positives) tp_prefix_sum, fp_prefix_sum = cls.calculate_sorted_prefix_sum( detections, true_positives, false_positives ) precision = np.divide(tp_prefix_sum, tp_prefix_sum + fp_prefix_sum) recall = tp_prefix_sum / positives ap, recall_interpolation, precision_interpolation = calculate_interpolation( recall, precision ) return InterpolationResult( recall, precision, ap, recall_interpolation, precision_interpolation, positives, 0 if len(tp_prefix_sum) == 0 else tp_prefix_sum[-1], 0 if len(fp_prefix_sum) == 0 else fp_prefix_sum[-1], )
[docs] @classmethod def calculate_class_results( cls, groundtruth: List[List[List[BoundingBox]]], detections: List[List[List[BoundingBox]]], calculate_interpolation, iou_threshold: float, ) -> InterpolationResult: """Calculate the interpolation, true positives and false positives for groundtruth and detection. Args: groundtruth: List that contain the groundtruth bounding boxes. detections: List that contain the detections bounding boxes. iou_threshold: The threshold that bounds the acceptance of a detection. calculate_interpolation: The function that calculates the interpolation. """ true_positives = [ cls.calculate_all_tp(gts, dets, iou_threshold) for gts, dets in zip(groundtruth, detections) ] true_positives = np.concatenate(true_positives) ret = cls.calculate_metrics( sum(groundtruth, []), sum(detections, []), true_positives, calculate_interpolation, ) return ret
    def calculate_results(
        self,
        groundtruth: List[List[BoundingBox]],
        detections: List[List[BoundingBox]],
        classes: List[str],
        iou_threshold: float = 0.5,
    ):
        """Calculate the metrics of all classes.

        Args:
            groundtruth: List of BoundingBoxes representing groundtruth bounding boxes;
            detections: List of BoundingBoxes representing detections bounding boxes;
            classes: Classes that should be considered, or the string "all".
            iou_threshold: IOU threshold indicating which detections will be
                considered TP or FP

        Returns:
            A result dictionary for every class and the AP of the combined result;
            ({}, 0 or 1) when there is no ground truth, (None, -1) when no
            requested class is valid.
        """
        ret = {}
        calculate_interpolation_func = (
            self.calculate_ap_every_point
            if self.use_every_point_interpolation
            else self.calculate_ap_11_point_interp
        )
        # Side effect: self.gt_classes is stored on the instance for later use
        # (e.g. by plot_precision_recall_curves).
        classes_bbs, self.gt_classes = self.split_bbs_per_class(groundtruth, detections)
        # No ground truth at all: score 1 only if there are also no detections.
        if len(self.gt_classes) == 0:
            return {}, 0 if len(sum(detections, [])) > 0 else 1
        # Collect the per-class bounding-box dicts of every requested class.
        all_bb_dicts = defaultdict(list)
        for class_, bb_dict in classes_bbs.items():
            # NOTE(review): parses as `(class_ in classes and class_ in
            # self.gt_classes) or classes == "all"` — confirm this precedence
            # is intended.
            if class_ in classes and class_ in self.gt_classes or classes == "all":
                for key, value in bb_dict.items():
                    all_bb_dicts[key].append(value)
        if len(all_bb_dicts.keys()) == 0:
            print("No valid classes given.")
            return None, -1
        name = ", ".join(classes)
        if len(classes) == len(self.gt_classes) or classes == "all":
            name = "all"
        ret[name] = self.calculate_class_results(
            all_bb_dicts["gt"],
            all_bb_dicts["det"],
            calculate_interpolation_func,
            iou_threshold,
        )
        return ret, ret[name].ap
[docs] def plot_precision_recall_curves( self, classes=None, show_interpolated_precision: bool = True, save_path: str = None, save_prefix: str = "plot", show_graphic: bool = True, ): """Plot the precision and recall curve. Args: classes: The classes that should be plotted, "all" can be a class. show_interpolated_precision: True if the interpolation should be shown. save_path: Save path of the plots, plots are not saved if no path is given. save_prefix: The prefix of all saved files. show_graphic: True if the plots should be shown. """ if not classes: classes = self.gt_classes | {"all"} if self.use_every_point_interpolation: label = "Interpolated precision" plot_type = "--r" else: label = "11-point interpolated precision" plot_type = "or" for iou in self.iou_thresholds: for class_label, result in self.results[iou].items(): if class_label in classes: title = f"Precision/Recall curve,\nClass: {class_label} IoU: {iou}" ndets = result.false_positives + result.true_positives txt = "".join( ( f"AP: {round(result.ap * 100, 1)}% ", f"detections: {ndets}, groundtruths:", f"{result.total_positives}", ) ) if result is None: raise OSError(f"Error: Class {class_label} could not be found.") plt.close() fig = plt.figure() ax = fig.add_axes((0.1, 0.2, 0.8, 0.7)) if show_interpolated_precision: ax.plot( result.recall_interpolation, result.precision_interpolation, plot_type, label=label, ) ax.plot(result.recall, result.precision, label="Precision") plt.xlabel("recall") plt.ylabel("precision") fig.text(0.5, 0.05, txt, ha="center") plt.title(title) plt.legend(shadow=True) plt.grid() if save_path is not None: os.makedirs(save_path, exist_ok=True) path = os.path.join( save_path, f"{save_prefix}_{iou}_{class_label}.png" ) plt.xlim(-0.05, 1) plt.ylim(-0.05, 1.05) plt.savefig(path) print(f"saved to {path}") if show_graphic: plt.show() plt.pause(0.05)