from typing import List
import numpy as np
from kitcar_ml.utils.bounding_box import BoundingBox
from kitcar_ml.utils.evaluation.evaluator import Evaluator
class SimpleEvaluator(Evaluator):
    """Evaluate detections by per-class precision/recall averaged over classes.

    Calling the instance computes the metrics and stores them in
    ``average_precision``, ``average_recall`` and ``f1score``; the call itself
    returns nothing.

    Args:
        iou: Intersection-over-union threshold handed to ``calculate_all_tp``
            when matching detections against groundtruth boxes.
    """

    def __init__(self, iou=0.5):
        super().__init__()
        self.iou_threshold = iou
        # Results of the most recent call; ``None`` until the evaluator ran.
        self.average_precision = None
        self.average_recall = None
        self.f1score = None

    def __call__(
        self,
        groundtruth: List[List[BoundingBox]] = None,
        detections: List[List[BoundingBox]] = None,
    ):
        """Calculate the average_precision, average_recall and f1score.

        Precision and recall are computed for every class returned by
        ``split_bbs_per_class`` and then averaged without weighting; the
        f1score is derived from those two averages. When there are no classes
        at all, every metric is set to 1.

        Args:
            groundtruth: Groundtruth bounding boxes, one inner list per sample
                (presumably per image — confirm with ``split_bbs_per_class``).
            detections: Detected bounding boxes, same layout as groundtruth.
        """
        classes_bbs, _ = self.split_bbs_per_class(groundtruth, detections)

        precisions = []
        recalls = []
        for bb_dict in classes_bbs.values():
            # Per-detection match result. Converting to an array and counting
            # non-zero entries works for bool and 0/1 int vectors alike, and
            # for classes without detections (``np.invert([])`` would raise
            # because an empty list becomes a float array).
            tp_vector = np.asarray(
                self.calculate_all_tp(
                    bb_dict["gt"], bb_dict["det"], self.iou_threshold
                )
            )
            true_positives = int(np.count_nonzero(tp_vector))
            # Every detection that is not a true positive is a false positive.
            false_positives = tp_vector.size - true_positives
            positives = sum(len(gt) for gt in bb_dict["gt"])

            precisions.append(self.precision(true_positives, false_positives))
            recalls.append(self.recall(true_positives, positives))

        if precisions:
            average_precision = np.mean(precisions)
            average_recall = np.mean(recalls)
            f1score = self.calculate_f1score(average_precision, average_recall)
        else:
            # No classes in either groundtruth or detections: treated as a
            # perfect score.
            average_precision = 1
            average_recall = 1
            f1score = 1

        self.average_precision = average_precision
        self.average_recall = average_recall
        self.f1score = f1score

    def __str__(self):
        return (
            f"average_rec: {self.average_recall}\n"
            f"average_prec: {self.average_precision}\n"
            f"f1score: {self.f1score}"
        )

    @staticmethod
    def calculate_f1score(precision: float, recall: float) -> float:
        """Calculate the f1score (harmonic mean) of the precision and recall.

        Args:
            precision: Precision value.
            recall: Recall value.

        Returns:
            ``2 * p * r / (p + r)``, or 0 when ``precision + recall == 0``.
        """
        if recall + precision == 0:
            return 0
        return 2 * precision * recall / (precision + recall)

    @staticmethod
    def precision(true_positive: int, false_positive: int) -> float:
        """Calculate the precision of the true positive and false positive.

        Args:
            true_positive: Number of true positive detections.
            false_positive: Number of false positive detections.

        Returns:
            ``tp / (tp + fp)``, or 0 when there are no detections at all.
        """
        if true_positive + false_positive == 0:
            return 0
        return true_positive / (true_positive + false_positive)

    @staticmethod
    def recall(true_positives: int, ngts: int) -> float:
        """Calculate the recall.

        Args:
            true_positives: True positives
            ngts: Number of groundtruth labels

        Returns:
            ``true_positives / ngts``, or 1 when there are no groundtruth
            labels (nothing could have been missed).
        """
        if ngts == 0:
            return 1
        return true_positives / ngts