"""Source code for kitcar_ml.onboarding.script."""

from argparse import ArgumentParser
from typing import List, Tuple

import PIL.Image
import torch
import torch.nn
import torchvision

from kitcar_ml.onboarding.model import OnboardingNet
from kitcar_ml.utils.data.data_loader.base_data_loader import BaseDataLoader
from kitcar_ml.utils.data.labeled_dataset import LabeledDataset


def get_dataloader(label_file: str, batch_size: int = 16) -> BaseDataLoader:
    """Create a dataloader over the labeled image dataset.

    Args:
        label_file: Path to the yaml label file describing the dataset.
        batch_size: Number of samples per batch.

    Returns:
        Labeled dataloader object that can be used to load training/test data.
    """
    dataset = LabeledDataset.from_yaml(label_file)
    # Every image is converted to a single-channel grayscale tensor.
    transform_data = torchvision.transforms.Compose(
        (torchvision.transforms.Grayscale(), torchvision.transforms.ToTensor())
    )

    def prepare_batch(
        batch: List[Tuple[PIL.Image.Image, List[Tuple[str, int]]]]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Turn a batch into the correct format for our model."""
        images, labels = zip(*batch)
        # Turn each image into a grayscale tensor and
        # then stack all images into one large tensor.
        image_tensors = torch.stack([transform_data(image) for image in images])
        # Read only the class ids from the labels:
        # labels = [[[image1.png, 9]], [[image2.png, 1]]]
        # class labels: [9,1]
        class_labels = torch.tensor([label[0][1] for label in labels])
        return image_tensors, class_labels

    return BaseDataLoader(dataset=dataset, batch_size=batch_size, collate_fn=prepare_batch)
def train(model: torch.nn.Module, label_file: str, epochs: int = 10):
    """Train the :py:class:`OnboardingNet` on the provided dataset.

    Args:
        model: The model to train.
        label_file: Path to the train yaml label file.
        epochs: Number of epochs to train.
    """
    train_loader = get_dataloader(label_file)
    optimizer = torch.optim.Adam(model.parameters())
    # The model outputs raw class scores, so cross-entropy is the matching loss.
    loss_function = torch.nn.CrossEntropyLoss()

    # Put the model into training mode (enables dropout / batchnorm updates).
    model.train()

    print("Start Training ...")
    for epoch in range(epochs):
        running_loss = 0.0
        batch_count = 0
        for data, class_labels in train_loader:
            # Gradients accumulate by default; clear them before each step.
            optimizer.zero_grad()
            output = model(data)
            loss = loss_function(output, class_labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            batch_count += 1
        average_loss = running_loss / max(batch_count, 1)
        print(f"Epoch {epoch + 1}/{epochs}: average loss {average_loss:.4f}")
    print("Finished Training.\n")
@torch.no_grad()
def test(model: torch.nn.Module, label_file: str):
    """Test the :py:class:`OnboardingNet` on the provided dataset.

    Args:
        model: The model to test.
        label_file: Path to the test yaml label file.
    """
    loader = get_dataloader(label_file, batch_size=1)

    # Switch the model into evaluation / inference mode.
    model.eval()

    print("Start Testing ...")
    num_correct = 0
    for batch, targets in loader:
        scores = model(batch)
        # The index of the highest score is the predicted class.
        predicted = scores.argmax(dim=1, keepdim=True)
        num_correct += predicted.eq(targets.view_as(predicted)).sum().item()

    accuracy = num_correct / len(loader.dataset) * 100
    print(f"Accuracy: {accuracy:.1f}%")
    print("Finished Testing.")
if __name__ == "__main__":
    cli = ArgumentParser()
    cli.add_argument("--train", action="store_true", help="Train model")
    cli.add_argument("--train_label_file", type=str, help="The training labels yaml file.")
    cli.add_argument("--test", action="store_true", help="Test model")
    cli.add_argument("--test_label_file", type=str, help="The testing labels yaml file.")
    arguments = cli.parse_args()

    # Instantiate the (untrained) network.
    model = OnboardingNet()

    if arguments.train:
        train(model=model, label_file=arguments.train_label_file)
    if arguments.test:
        test(model=model, label_file=arguments.test_label_file)