import random
from typing import List, Tuple
from kitcar_ml.utils.bounding_box import BoundingBox
def get_shift_scale(error: float = 0.5) -> Tuple[Tuple[float, float], Tuple[float, float]]:
    """Draw a random translation and scaling for perturbing a bounding box.

    The perturbation makes a predicted box no longer overlap its ground truth
    by 100%.

    Args:
        error: Perturbation magnitude. Each shift component is drawn uniformly
            from ``[-10 * error, 10 * error]``. Each scale component is drawn
            uniformly from ``[1 - error, 1 + error]``.

    Returns:
        A pair ``(shift, scale)``. Each element is a 2-tuple of floats, with
        the x component first and the y component second.
    """
    max_shift = 10 * error
    dx = random.uniform(-max_shift, max_shift)
    dy = random.uniform(-max_shift, max_shift)

    low, high = 1 - error, 1 + error
    sx = random.uniform(low, high)
    sy = random.uniform(low, high)

    return (dx, dy), (sx, sy)
def modify_bounding_box(
    bb,
    error,
    accuracy,
    shift,
    scale,
):
    """Return a shifted/scaled copy of ``bb`` with a random confidence.

    Args:
        bb: Source box. Its ``shift_and_scale(shift, scale)`` result is
            unpacked as the leading positional arguments of the new
            ``BoundingBox``, and its ``class_label`` is copied over.
        error: Lowers the bottom of the confidence range.
        accuracy: Upper anchor of the confidence lower bound.
        shift: Translation passed through to ``bb.shift_and_scale``.
        scale: Scale factors passed through to ``bb.shift_and_scale``.

    Returns:
        A new ``BoundingBox`` whose confidence is drawn uniformly from
        ``[accuracy - error, 1]``.
    """
    corners = bb.shift_and_scale(shift, scale)
    label = bb.class_label
    score = random.uniform(accuracy - error, 1)
    return BoundingBox(*corners, label, confidence=score)
def predict_with_overlap(
    image,
    gt_bbs: List[BoundingBox],
    accuracy: float = 0.5,
    additional_labels: int = 0,
    error: float = 0.5,
) -> List[BoundingBox]:
    """Simulate the predictions of a model from ground-truth bounding boxes.

    A bounding box is defined by two points, min and max:
    ``bb=[xmin, ymin, xmax, ymax]``.

    Each ground-truth box is kept with probability ``accuracy``. A kept box is
    randomly shifted and scaled via ``get_shift_scale(error)``. A box that is
    not kept is shifted by twice its own width and height with unit scale,
    presumably so that it no longer overlaps its ground truth. Every returned
    box, including the false positives, gets a confidence drawn uniformly from
    ``[accuracy - error, 1]``.

    Args:
        image: The input image as a tensor. Only ``image.shape[1]`` and
            ``image.shape[2]`` are read.
        gt_bbs: The list of groundtruth bounding boxes.
        accuracy: Probability that a ground-truth box is predicted with a
            small random perturbation rather than moved away.
        additional_labels: Number of false-positive boxes appended to the
            detections.
        error: Magnitude of the random shift/scale perturbation. Also lowers
            the minimum confidence to ``accuracy - error``.

    Returns:
        One perturbed box per ground-truth box (in input order), followed by
        ``additional_labels`` perturbed false-positive boxes.

    Raises:
        IndexError: If ``gt_bbs`` is empty while ``additional_labels > 0``,
            because false positives borrow ``gt_bbs[0].class_label``.
    """
    det_bbs = [
        modify_bounding_box(gt, error, accuracy, *get_shift_scale(error))
        if random.random() < accuracy
        else modify_bounding_box(
            gt, error, accuracy, (2 * (gt.x2 - gt.x1), 2 * (gt.y2 - gt.y1)), (1, 1)
        )
        for gt in gt_bbs
    ]

    fp_labels = []
    for _ in range(additional_labels):
        # NOTE(review): if image is (C, H, W), x2=shape[1] is the height and
        # y2=shape[2] the width, which may be swapped -- confirm against the
        # BoundingBox (x1, y1, x2, y2) argument order.
        base = BoundingBox(
            10,
            10,
            image.shape[1],
            image.shape[2],
            gt_bbs[0].class_label,
        )
        # modify_bounding_box builds and returns a new BoundingBox, so the
        # result must be kept for the perturbation and confidence to apply.
        fp_labels.append(
            modify_bounding_box(base, error, accuracy, *get_shift_scale(error))
        )

    return det_bbs + fp_labels