# Source code for kitcar_ml.onboarding.model

import torch
import torch.nn
import torch.nn.functional


class OnboardingNet(torch.nn.Module):
    """Neural network class.

    The :func:`forward` defines how images are passed through the network.
    A single fully connected layer maps a flattened 28 x 28 input to 10
    outputs, which are normalized with a softmax.
    """

    def __init__(self):
        super().__init__()
        # Define fully connected layer: 28 * 28 flattened inputs -> 10 outputs
        self.fc = torch.nn.Linear(28 * 28, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Handle input.

        Args:
            x: Batch of input images. Every dimension after the first (batch)
                dimension is flattened, so each sample must contain exactly
                28 * 28 elements, e.g. shape ``(batch, 1, 28, 28)`` or
                ``(batch, 28, 28)``.

        Return:
            Class probabilities as a ``(batch, 10)`` tensor; each row sums
            to 1 because of the softmax over dim 1.
        """
        # Keep the batch dimension, flatten everything else for the linear layer.
        x = torch.flatten(x, 1)
        x = self.fc(x)
        # NOTE(review): if this net is trained with torch.nn.CrossEntropyLoss,
        # that loss applies log-softmax itself; feeding it softmax outputs
        # applies softmax twice — confirm against the training code.
        output = torch.nn.functional.softmax(x, dim=1)
        return output

    def load(self, path: str):
        """Load model parameters from file.

        The file is expected to contain a state dict as written by
        :meth:`save`.

        Args:
            path: Path to file

        Raises:
            FileNotFoundError: If ``path`` does not exist.
            RuntimeError: If the stored parameters do not match this model's
                parameter names or shapes (``load_state_dict`` is strict).
        """
        # torch.load unpickles the file: only load checkpoints from trusted
        # sources. map_location="cpu" allows checkpoints saved on a GPU to be
        # read on CPU-only machines; load_state_dict then copies the values
        # into this module's existing parameters on whatever device they are.
        state_dict = torch.load(path, map_location="cpu")
        self.load_state_dict(state_dict)

    def save(self, path: str):
        """Save model parameters to file.

        Only the parameters (the module's state dict) are stored, not the
        class itself; restore them with :meth:`load` on a new instance.

        Args:
            path: Path to file
        """
        torch.save(self.state_dict(), path)