from abc import ABC, abstractmethod
from typing import List, Tuple, Type
import numpy as np
from torch.utils.data import DataLoader
from kitcar_ml.utils.bounding_box import BoundingBox
from kitcar_ml.utils.data.data_loader.bbox_data_loader import BBoxDataLoader
from kitcar_ml.utils.evaluation.evaluator import Evaluator
from kitcar_ml.utils.evaluation.simple_evaluator import SimpleEvaluator
class DetectionModel(ABC):
    """Abstract base class for object-detection models.

    Subclasses must implement ``fit``, ``predict``, ``export_to_onnx``, ``save``
    and ``load``. ``evaluate`` is implemented here on top of ``predict``.
    """

    @abstractmethod
    def fit(
        self,
        data_loader: DataLoader,
        val_data_loader: DataLoader,
        epochs: int = 10,
        **kwargs
    ) -> None:
        """Train the model on the data provided by ``data_loader``.

        Args:
            data_loader: Loader providing the training data.
            val_data_loader: Second loader passed alongside the training data;
                presumably used for validation (implementation-defined).
            epochs: Requested number of training epochs (default ``10``).
            **kwargs: Implementation-specific training options.
        """

    @abstractmethod
    def predict(
        self, images: List[np.ndarray], **kwargs
    ) -> List[Tuple[List[np.ndarray], List[str], List[float]]]:
        """Predict object locations for every image in ``images``.

        Args:
            images: The images to run the detector on.
            **kwargs: Implementation-specific inference options.

        Returns:
            One ``(arrays, labels, floats)`` tuple per input image. ``evaluate``
            unpacks each tuple directly into ``BoundingBox.create_bounding_boxes``;
            presumably the lists hold box coordinates, class labels and
            confidence scores respectively (verify against implementations).
        """

    def evaluate(
        self,
        data_loader: BBoxDataLoader,
        evaluator_type: Type[Evaluator] = SimpleEvaluator,
        **kwargs
    ) -> Evaluator:
        """Compare the model's predictions on ``data_loader`` with its ground truth.

        For every ``(images, bboxes, labels)`` batch, ground-truth boxes are
        built per image from the zipped ``bboxes``/``labels`` first, and then
        ``self.predict(images, **kwargs)`` is run and its per-image outputs are
        turned into detection boxes.

        Args:
            data_loader: Loader yielding ``(images, bboxes, labels)`` batches.
            evaluator_type: Evaluator class to use instead of the default
                ``SimpleEvaluator``. It is instantiated with no arguments and
                then called as ``evaluator(groundtruth, detections)``.
            **kwargs: Forwarded unchanged to ``predict``.

        Returns:
            The evaluator instance after it has been run on all batches.
        """
        truth_boxes = []
        predicted_boxes = []
        for batch_images, batch_bboxes, batch_labels in data_loader:
            truth_boxes.extend(
                BoundingBox.create_bounding_boxes(image_boxes, image_labels)
                for image_boxes, image_labels in zip(batch_bboxes, batch_labels)
            )
            predicted_boxes.extend(
                BoundingBox.create_bounding_boxes(*image_prediction)
                for image_prediction in self.predict(batch_images, **kwargs)
            )

        evaluator = evaluator_type()
        evaluator(truth_boxes, predicted_boxes)
        return evaluator

    @abstractmethod
    def export_to_onnx(self, output_file: str):
        """Export this model in ONNX format.

        Args:
            output_file: Path of the file to write the exported model to.
        """

    @abstractmethod
    def save(self, file: str):
        """Save the internal model weights to a file.

        Args:
            file: Name of the target file; it should use the ``.pth`` extension.
        """

    @classmethod
    @abstractmethod
    def load(cls, file: str):
        """Create a model from a ``.pth`` file containing saved weights.

        Args:
            file: Path to the ``.pth`` file holding the saved model.

        Returns:
            The model restored from ``file``.
        """