from argparse import ArgumentParser
from typing import List, Tuple
import PIL.Image
import torch
import torch.nn
import torchvision
from kitcar_ml.onboarding.model import OnboardingNet
from kitcar_ml.utils.data.data_loader.base_data_loader import BaseDataLoader
from kitcar_ml.utils.data.labeled_dataset import LabeledDataset
def get_dataloader(label_file: str, batch_size: int = 16) -> BaseDataLoader:
    """Create a dataloader for a labeled image dataset.

    Args:
        label_file: Path to the yaml label file that describes the dataset.
        batch_size: Number of samples per batch.

    Return:
        Labeled dataloader object that can be used to load training/test data.
        Each batch is a tuple ``(images, class_labels)``. ``images`` is a
        single tensor with all grayscale images of the batch stacked along a
        new leading dimension. ``class_labels`` is a 1-D tensor with one class
        id per image.
    """
    dataset = LabeledDataset.from_yaml(label_file)
    transform_data = torchvision.transforms.Compose(
        (torchvision.transforms.Grayscale(), torchvision.transforms.ToTensor())
    )

    def prepare_batch(
        batch: List[Tuple[PIL.Image.Image, List[Tuple[str, int]]]]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Turn a batch into the correct format for our model.

        Only the class id of each image's first label is kept. All images in a
        batch must have the same size, otherwise ``torch.stack`` fails.
        """
        images, labels = zip(*batch)
        # Turn each image into a grayscale tensor and
        # then stack all images into one large tensor.
        image_tensors = torch.stack([transform_data(image) for image in images])
        # Read only the class ids from the labels:
        # labels = [[[image1.png, 9]], [[image2.png, 1]]]
        # class labels: [9,1]
        class_labels = torch.tensor(
            [image_labels[0][1] for image_labels in labels]
        )
        return image_tensors, class_labels

    return BaseDataLoader(
        dataset=dataset, batch_size=batch_size, collate_fn=prepare_batch
    )
def train(model: torch.nn.Module, label_file: str, epochs: int = 10):
    """Train the :py:class:`OnboardingNet` on the provided dataset.

    For every epoch, each batch of the labeled dataset goes through the same
    steps. It is passed through ``model``, a cross-entropy loss is computed
    against the integer class ids, and the parameters are updated with Adam.
    The mean loss of each epoch is printed.

    Args:
        model: The model to train. Its parameters are updated in place.
        label_file: Path to the train yaml label file.
        epochs: Number of epochs to train.
    """
    train_loader = get_dataloader(label_file)
    # NOTE(review): Adam with lr=1e-3 is a conventional default, not a tuned
    # value for OnboardingNet; adjust if training does not converge.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    # CrossEntropyLoss applies log_softmax internally. Since log_softmax is
    # idempotent, this is also correct if the model already outputs
    # log-probabilities. The targets are integer class-id tensors, as built
    # by get_dataloader's collate function.
    loss_function = torch.nn.CrossEntropyLoss()

    # test() puts the model into eval mode, so explicitly switch back to
    # training mode in case the model has dropout/batch-norm layers.
    model.train()

    print("Start Training ...")
    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for data, class_labels in train_loader:
            optimizer.zero_grad()
            output = model(data)
            loss = loss_function(output, class_labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1
        # Guard against an empty dataset instead of dividing by zero.
        mean_loss = running_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch + 1}/{epochs}: mean loss {mean_loss:.4f}")
    print("Finished Training.\n")
@torch.no_grad()
def test(model: torch.nn.Module, label_file: str):
    """Test the :py:class:`OnboardingNet` on the provided dataset.

    Prints the classification accuracy over all evaluated samples. Gradient
    tracking is disabled for the whole function via ``torch.no_grad``.

    Args:
        model: The model to test. It is switched to evaluation mode.
        label_file: Path to the test yaml label file.
    """
    test_loader = get_dataloader(label_file, batch_size=1)
    # Set model in testing / evaluation mode
    model.eval()
    # Count correct predictions and the number of samples actually evaluated
    correct = 0
    total = 0
    print("Start Testing ...")
    for data, class_labels in test_loader:
        output = model(data)
        # Index of the highest score along dim 1 is the predicted class id.
        prediction = output.argmax(dim=1, keepdim=True)
        correct += prediction.eq(class_labels.view_as(prediction)).sum().item()
        total += class_labels.numel()
    if total == 0:
        # Avoid ZeroDivisionError when the test set is empty.
        print("Accuracy: n/a (no test samples)")
    else:
        correct_percentage = correct / total * 100
        print(f"Accuracy: {correct_percentage:.1f}%")
    print("Finished Testing.")
if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--train", action="store_true", help="Train model")
    parser.add_argument(
        "--train_label_file",
        type=str,
        help="The training labels yaml file.",
    )
    parser.add_argument("--test", action="store_true", help="Test model")
    parser.add_argument(
        "--test_label_file",
        type=str,
        help="The testing labels yaml file.",
    )
    args = parser.parse_args()

    # Validate the flag combinations up front, so a missing label file
    # produces a usage message instead of None being passed on to the
    # dataset loader.
    if not (args.train or args.test):
        parser.error("nothing to do: pass --train and/or --test")
    if args.train and args.train_label_file is None:
        parser.error("--train requires --train_label_file")
    if args.test and args.test_label_file is None:
        parser.error("--test requires --test_label_file")

    # Load model
    model = OnboardingNet()
    if args.train:
        train(model=model, label_file=args.train_label_file)
    if args.test:
        test(model=model, label_file=args.test_label_file)