Source code for kitcar_ml.traffic_sign_detection.detection_model

from abc import ABC, abstractmethod
from typing import List, Tuple, Type

import numpy as np
from torch.utils.data import DataLoader

from kitcar_ml.utils.bounding_box import BoundingBox
from kitcar_ml.utils.data.data_loader.bbox_data_loader import BBoxDataLoader
from kitcar_ml.utils.evaluation.evaluator import Evaluator
from kitcar_ml.utils.evaluation.simple_evaluator import SimpleEvaluator


class DetectionModel(ABC):
    """Abstract base class for object-detection models.

    Subclasses implement training (:meth:`fit`), inference (:meth:`predict`),
    ONNX export and weight (de)serialization. :meth:`evaluate` is provided on
    top of :meth:`predict`.
    """

    @abstractmethod
    def fit(
        self,
        data_loader: DataLoader,
        val_data_loader: DataLoader,
        epochs: int = 10,
        **kwargs,
    ) -> None:
        """Train the model on the given data_loader.

        Args:
            data_loader: Loader providing the training data.
            val_data_loader: Loader providing the validation data.
            epochs: Number of training epochs.
            **kwargs: Implementation-specific training options.
        """

    @abstractmethod
    def predict(
        self, images: List[np.ndarray], **kwargs
    ) -> List[Tuple[List[np.ndarray], List[str], List[float]]]:
        """Take in a list of images and return predictions for object locations.

        Args:
            images: The images to run inference on.
            **kwargs: Implementation-specific inference options.

        Returns:
            One tuple per image containing the predicted boxes, their labels
            and a list of floats (presumably per-box confidence scores — confirm
            against implementations). Each tuple is passed unpacked to
            ``BoundingBox.create_bounding_boxes`` by :meth:`evaluate`.
        """

    def evaluate(
        self,
        data_loader: BBoxDataLoader,
        evaluator_type: Type[Evaluator] = SimpleEvaluator,
        **kwargs,
    ) -> Evaluator:
        """Run an evaluator over every batch of ``data_loader``.

        Ground truth is built from the boxes and labels yielded by the loader.
        Detections come from :meth:`predict` on the batch images. Both are
        converted with ``BoundingBox.create_bounding_boxes`` and handed to a
        freshly constructed evaluator.

        Args:
            data_loader: Loader yielding ``(images, bboxes, labels)`` batches.
            evaluator_type: Evaluator class to instantiate (called without
                arguments). Use it to run a different evaluator than
                :class:`SimpleEvaluator`.
            **kwargs: Forwarded unchanged to :meth:`predict`.

        Returns:
            The evaluator with results, i.e. the instance after it has been
            called with ``(groundtruth, detections)``.
        """
        detections = []
        groundtruth = []
        for images, bboxes, labels in data_loader:
            # One ground-truth entry per image: pair each image's boxes with
            # the labels at the same position in the batch.
            groundtruth.extend(
                BoundingBox.create_bounding_boxes(boxes, target_labels)
                for boxes, target_labels in zip(bboxes, labels)
            )
            # One detection entry per image: each prediction tuple is unpacked
            # directly into the factory.
            detections.extend(
                BoundingBox.create_bounding_boxes(*prediction)
                for prediction in self.predict(images, **kwargs)
            )

        evaluator = evaluator_type()
        evaluator(groundtruth, detections)
        return evaluator

    @abstractmethod
    def export_to_onnx(self, output_file: str) -> None:
        """Export this model into the ONNX format.

        Args:
            output_file: Path to the output file.
        """

    @abstractmethod
    def save(self, file: str) -> None:
        """Save the internal model weights to a file.

        Args:
            file: The name of the file. Should have a .pth file extension.
        """

    # classmethod must be the outermost decorator so abstractmethod applies to
    # the underlying function.
    @classmethod
    @abstractmethod
    def load(cls, file: str) -> "DetectionModel":
        """Load a model from a .pth file containing the model weights.

        Args:
            file: The path to the .pth file containing the saved model.

        Returns:
            The model loaded from the file.
        """