import torch
import torch.nn
import torch.nn.functional
class OnboardingNet(torch.nn.Module):
    """Neural network class.

    A single fully connected layer maps a flattened 28 * 28 input to 10
    outputs, followed by a softmax. The :func:`forward` defines how images
    are passed through the network.
    """

    def __init__(self):
        super().__init__()
        # Define fully connected layer: 28 * 28 flattened inputs -> 10 outputs.
        self.fc = torch.nn.Linear(28 * 28, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Handle input.

        Args:
            x: Input image batch. Every dimension after the first (batch)
                dimension is flattened, so those trailing dimensions must hold
                28 * 28 elements in total (e.g. ``(N, 1, 28, 28)`` or
                ``(N, 28, 28)``).

        Return:
            Tensor of shape ``(N, 10)`` with per-class probabilities
            (softmax over dim 1, so each row sums to 1).
        """
        x = torch.flatten(x, 1)
        x = self.fc(x)
        # NOTE(review): softmax is applied here, so the output is probabilities,
        # not logits. If training uses torch.nn.CrossEntropyLoss (which applies
        # log-softmax internally) the softmax would effectively be applied
        # twice -- confirm against the training code.
        output = torch.nn.functional.softmax(x, dim=1)
        return output

    def load(self, path: str):
        """Load model parameters from file.

        The file is expected to contain a state dict as written by
        :meth:`save`. Parameters are copied into this module's existing
        tensors, keeping their current device.

        Args:
            path: Path to file

        Raises:
            FileNotFoundError: If ``path`` does not exist.
            RuntimeError: If the stored keys/shapes do not match this model
                (``load_state_dict`` runs in strict mode).
        """
        # map_location="cpu" lets weights saved on a GPU machine be read on a
        # CPU-only one; load_state_dict then copies them onto this module's
        # parameter devices.
        # SECURITY: torch.load unpickles the file -- only load trusted files.
        state_dict = torch.load(path, map_location="cpu")
        self.load_state_dict(state_dict)

    def save(self, path: str):
        """Save model parameters to file.

        Only the state dict (parameters and buffers) is stored, not the
        pickled module object, so it can be restored with :meth:`load`.

        Args:
            path: Path to file
        """
        torch.save(self.state_dict(), path)