import itertools
from dataclasses import dataclass
from functools import cached_property
from typing import Iterable, Optional, Tuple
@dataclass(frozen=True)
class BoundingBox:
    """Immutable axis-aligned bounding box with a class label.

    The four coordinates are absolute coordinates in the image. Both the
    minimum (``x1``, ``y1``) and the maximum (``x2``, ``y2``) corners are
    *inclusive*, which is why :attr:`area` adds one to width and height.

    Indexing an instance (``box[0]`` .. ``box[3]``) yields the entries of
    :attr:`coordinates`, i.e. ``x1, y1, x2, y2``.
    """

    x1: float
    """Smallest x coordinate of the bounding box (inclusive)."""
    y1: float
    """Smallest y coordinate of the bounding box (inclusive)."""
    x2: float
    """Largest x coordinate of the bounding box (inclusive)."""
    y2: float
    """Largest y coordinate of the bounding box (inclusive)."""
    class_label: str
    """Class that the model gave the bounding box."""
    confidence: Optional[float] = None
    """Confidence of the model that this bounding box is correct, if known."""

    def __getitem__(self, idx):
        """Return the entry (or slice) of :attr:`coordinates` selected by ``idx``."""
        return self.coordinates[idx]

    def __post_init__(self):
        """Check that the box has non-negative width and height.

        Raises:
            AssertionError: If ``x2 < x1`` or ``y2 < y1`` (or the comparison
                is undefined, e.g. for NaN).
        """
        # Raised explicitly instead of via ``assert`` so the validation is not
        # stripped under ``python -O``. AssertionError is kept as the exception
        # type for backward compatibility. ``not (a >= b)`` rather than
        # ``a < b`` so that NaN coordinates are still rejected.
        if not (self.x2 >= self.x1):
            raise AssertionError(
                f"x2 ({self.x2}) must be greater than or equal to x1 ({self.x1})"
            )
        if not (self.y2 >= self.y1):
            raise AssertionError(
                f"y2 ({self.y2}) must be greater than or equal to y1 ({self.y1})"
            )

    # NOTE: ``cached_property`` stores its value in the instance ``__dict__``
    # directly, so it works even though the dataclass is frozen.
    @cached_property
    def coordinates(self) -> Tuple[float, float, float, float]:
        """The coordinates as the tuple ``(x1, y1, x2, y2)``."""
        return self.x1, self.y1, self.x2, self.y2

    @cached_property
    def center_point(self) -> Tuple[float, float]:
        """The ``(x, y)`` midpoint between the two corners."""
        return (self.x1 + self.x2) / 2, (self.y1 + self.y2) / 2

    @cached_property
    def width(self) -> float:
        """Difference ``x2 - x1`` (does not count the inclusive boundary)."""
        return self.x2 - self.x1

    @cached_property
    def height(self) -> float:
        """Difference ``y2 - y1`` (does not count the inclusive boundary)."""
        return self.y2 - self.y1

    @cached_property
    def area(self) -> float:
        """The area of the bounding box.

        Computed as ``(width + 1) * (height + 1)``, because both boundaries
        are inside the bounding box.
        """
        return (self.width + 1) * (self.height + 1)

    def shift_and_scale(self, shift, scale) -> Tuple[float, float, float, float]:
        """Scale the coordinates, then shift them.

        Each coordinate ``c`` on an axis becomes ``shift + scale * c`` using
        that axis' entries of ``shift`` and ``scale``. The instance itself is
        not modified.

        Args:
            shift: Indexable whose entries ``[0]`` and ``[1]`` are the x and y
                translation.
            scale: Indexable whose entries ``[0]`` and ``[1]`` are the x and y
                scaling factor.

        Returns:
            The transformed coordinates as the tuple ``(x1, y1, x2, y2)``.
        """
        x1 = shift[0] + scale[0] * self.x1
        x2 = shift[0] + scale[0] * self.x2
        y1 = shift[1] + scale[1] * self.y1
        y2 = shift[1] + scale[1] * self.y2
        return x1, y1, x2, y2

    @classmethod
    def create_bounding_boxes(
        cls,
        boxes: Iterable[Iterable[float]],
        labels: Iterable[str],
        scores: Optional[Iterable[float]] = None,
    ):
        """Create multiple bounding boxes.

        Args:
            boxes: One ``(x1, y1, x2, y2)`` iterable per bounding box.
            labels: One class label per bounding box.
            scores: One confidence per bounding box. If this is not an
                iterable (e.g. ``None``), every box gets ``confidence=None``.

        Returns:
            A list with one instance per ``(box, label, score)`` triple. The
            inputs are combined with ``zip``, so the result is as long as the
            shortest of them.
        """
        if not isinstance(scores, Iterable):
            # No usable scores: pair every box with ``None`` instead.
            scores = itertools.repeat(None)
        return [cls(*box, label, score) for box, label, score in zip(boxes, labels, scores)]

    @classmethod
    def iou(cls, box_a: "BoundingBox", box_b: "BoundingBox") -> float:
        """Calculate the IOU (Intersection over Union) of two bounding boxes.

        Returns:
            ``intersection_area / union_area``; ``0`` when the boxes do not
            overlap.
        """
        intersection_area = cls.intersection_area(box_a, box_b)
        # Pass the intersection along so it is not computed a second time.
        union = cls.union_area(box_a, box_b, intersection_area=intersection_area)
        # For validated boxes each area is at least 1, so the union is non-zero.
        return intersection_area / union

    @classmethod
    def union_area(
        cls, box_a: "BoundingBox", box_b: "BoundingBox", intersection_area=None
    ) -> float:
        """Calculate the union area of two bounding boxes.

        Args:
            box_a: First bounding box.
            box_b: Second bounding box.
            intersection_area: Precomputed intersection area of the two boxes.
                Computed via :meth:`intersection_area` when ``None``.

        Returns:
            ``area(box_a) + area(box_b) - intersection_area`` as a float.
        """
        if intersection_area is None:
            intersection_area = cls.intersection_area(box_a, box_b)
        return float(box_a.area + box_b.area - intersection_area)

    @staticmethod
    def intersection_area(box_a: "BoundingBox", box_b: "BoundingBox") -> float:
        """Calculate the intersection area of two bounding boxes.

        The boxes are accessed by index only (``[0]`` .. ``[3]`` being
        ``x1, y1, x2, y2``).

        Returns:
            The area of the overlap, counting the inclusive boundaries, or
            ``0`` if the boxes do not overlap.
        """
        x_a = max(box_a[0], box_b[0])
        y_a = max(box_a[1], box_b[1])
        x_b = min(box_a[2], box_b[2])
        y_b = min(box_a[3], box_b[3])
        # There is no intersecting area if the width or height is negative.
        return max(0, x_b - x_a + 1) * max(0, y_b - y_a + 1)