Source code for kitcar_ml.utils.data.bbox_dataset

import warnings

import torch
from albumentations.core import bbox_utils

from kitcar_ml.utils.data.labeled_dataset import LabeledDataset


class BBoxDataset(LabeledDataset):
    """Labeled dataset whose labels describe axis-aligned bounding boxes.

    Each label is a sequence indexed through ``self.attributes``; the
    attributes ``"class_id"``, ``"x1"``, ``"y1"``, ``"x2"`` and ``"y2"`` must
    be present in it.
    """

    def _bbox_attribute_indices(self):
        """Return the label positions of the ``x1, y1, x2, y2`` attributes.

        Propagates whatever ``self.attributes.index`` raises for a missing
        name (``ValueError`` if ``attributes`` is a list/tuple).
        """
        return tuple(self.attributes.index(name) for name in ("x1", "y1", "x2", "y2"))

    @staticmethod
    def _clamp_span(start, end, upper):
        """Clamp both ends of a 1-D span to ``[0, upper]`` and order them.

        A span lying completely outside ``[0, upper]`` collapses to zero
        length, so the subsequent validity check rejects it as degenerate.
        """
        low, high = sorted((min(max(start, 0), upper), min(max(end, 0), upper)))
        return low, high

    def __getitem__(self, index: int):
        """Return ``(image, boxes, target_labels)`` for the sample at ``index``.

        Box coordinates are clamped to the image and ordered so that
        ``x1 <= x2`` and ``y1 <= y2``. Boxes rejected by albumentations'
        ``check_bbox`` after normalization (e.g. zero width or height) are
        dropped with a warning.

        If ``self._cache`` is set, the cached entry is returned unchanged.

        Returns:
            image: the image returned by ``LabeledDataset.__getitem__``; its
                ``.size`` is read as ``(width, height)``.
            boxes: tensor of shape ``(num_boxes, 4)`` holding
                ``(x1, y1, x2, y2)`` in image coordinates.
            target_labels: ``self.classes[class_id]`` for every kept box, in
                the same order as ``boxes``.
        """
        if self._cache is not None:
            return self._cache[index]

        image, labels = super().__getitem__(index)

        class_id_index = self.attributes.index("class_id")
        x1_index, y1_index, x2_index, y2_index = self._bbox_attribute_indices()
        img_width, img_height = image.size

        boxes = []
        target_labels = []
        for label in labels:
            # Clamp x/y to the image (and order each pair) before validating.
            x1, x2 = self._clamp_span(label[x1_index], label[x2_index], img_width)
            y1, y2 = self._clamp_span(label[y1_index], label[y2_index], img_height)
            bbox = [x1, y1, x2, y2]
            try:
                bbox_utils.check_bbox(
                    bbox_utils.normalize_bbox(bbox, rows=img_height, cols=img_width)
                )
            except ValueError as e:
                warnings.warn(f"Warning: Dropping invalid bbox: {e}")
                continue
            boxes.append(bbox)
            target_labels.append(self.classes[label[class_id_index]])

        # NOTE(review): torch infers the dtype from the coordinate values
        # (int64 for integer labels, float32 for an empty list) -- confirm
        # downstream consumers accept that or pass an explicit dtype.
        return image, torch.tensor(boxes).view(-1, 4), target_labels

    def prune_small_boxes(self, min_area: int) -> None:
        """Keep only boxes with a minimal area.

        Rewrites ``self.labels`` in place, dropping every label whose box area
        ``|(x2 - x1) * (y2 - y1)|`` is smaller than ``min_area``. The area is
        computed from the raw label coordinates, i.e. without the clamping to
        the image that ``__getitem__`` applies.

        Args:
            min_area: minimal area a box needs in order to be kept.
        """
        x1_index, y1_index, x2_index, y2_index = self._bbox_attribute_indices()

        def get_area(label):
            return abs(
                (label[x2_index] - label[x1_index]) * (label[y2_index] - label[y1_index])
            )

        # NOTE(review): entries already held in ``self._cache`` are not
        # updated here; __getitem__ returns cached items as-is -- confirm the
        # cache is not populated before pruning.
        for key, labels in self.labels.items():
            self.labels[key] = [label for label in labels if get_area(label) >= min_area]