import torch
from albumentations.core import bbox_utils
from kitcar_ml.utils.data.labeled_dataset import LabeledDataset
class BBoxDataset(LabeledDataset):
    """Labeled dataset that serves each sample as image, bounding boxes and classes.

    Label rows are addressed by column: the positions of ``class_id``, ``x1``,
    ``y1``, ``x2`` and ``y2`` are looked up in ``self.attributes``.
    """

    def __getitem__(self, index: int):
        """Return the sample at ``index`` as ``(image, boxes, target_labels)``.

        If ``self._cache`` is not ``None`` the cached entry for ``index`` is
        returned directly and nothing below is executed.

        Otherwise the parent's ``(image, labels)`` pair is loaded and every
        label row is turned into an ``[x1, y1, x2, y2]`` box clamped to the
        image. Boxes rejected by albumentations' ``check_bbox`` are dropped
        with a printed warning.

        Args:
            index: Position of the sample in the dataset.

        Returns:
            A tuple of the image, a tensor of the kept boxes viewed as
            ``(-1, 4)``, and a list with the ``self.classes`` entry of every
            kept box (same order as the boxes).
        """
        if self._cache is not None:
            return self._cache[index]

        image, labels = super().__getitem__(index)

        # Column positions inside a label row (lookups happen left to right).
        class_col, left_col, top_col, right_col, bottom_col = (
            self.attributes.index(name)
            for name in ("class_id", "x1", "y1", "x2", "y2")
        )

        # image.size is unpacked as (width, height)
        # -- presumably a PIL image; confirm against LabeledDataset.
        width, height = image.size

        kept_boxes = []
        kept_classes = []
        for row in labels:
            raw = [row[col] for col in (left_col, top_col, right_col, bottom_col)]

            # Sorting a coordinate pair together with the image bounds and
            # keeping the two middle values clamps the pair to the image and
            # also puts it in ascending order.
            left, right = sorted((0, raw[0], raw[2], width))[1:3]
            top, bottom = sorted((0, raw[1], raw[3], height))[1:3]
            candidate = [left, top, right, bottom]

            try:
                normalized = bbox_utils.normalize_bbox(
                    candidate, rows=height, cols=width
                )
                bbox_utils.check_bbox(normalized)
            except ValueError as e:
                print(f"Warning: Dropping invalid bbox: {e}")
            else:
                kept_boxes.append(candidate)
                # Translate the stored class id via self.classes.
                kept_classes.append(self.classes[row[class_col]])

        # view(-1, 4) keeps a (0, 4) shape even when every box was dropped.
        box_tensor = torch.tensor(kept_boxes).view(-1, 4)
        return image, box_tensor, kept_classes

    def prune_small_boxes(self, min_area: int):
        """Remove every label whose box area is below ``min_area``.

        The area is the absolute value of ``(x2 - x1) * (y2 - y1)``, so the
        order of the corner coordinates does not matter. ``self.labels`` is
        updated in place, key by key.

        Args:
            min_area: Smallest area a box may have to be kept.
        """
        left_col, top_col, right_col, bottom_col = (
            self.attributes.index(name) for name in ("x1", "y1", "x2", "y2")
        )

        for key, rows in self.labels.items():
            survivors = []
            for row in rows:
                span_x = row[right_col] - row[left_col]
                span_y = row[bottom_col] - row[top_col]
                if abs(span_x * span_y) >= min_area:
                    survivors.append(row)
            # Only values of existing keys are replaced, so the dict size is
            # unchanged while iterating.
            self.labels[key] = survivors