3. Defining Neural Networks
This part of the onboarding introduces a very simple neural network and shows you how to use it with your dataset.
3.1. Model
PyTorch makes defining models simple.
The following class is defined in kitcar_ml/onboarding/model.py.
It consists of a single fully connected layer and will therefore perform poorly on the MNIST dataset.
```python
import torch
import torch.nn
import torch.nn.functional


class OnboardingNet(torch.nn.Module):
    """Neural network class.

    The :func:`forward` defines how images are passed through the network.
    """

    def __init__(self):
        super().__init__()
        # Define fully connected layer: 28 * 28 input pixels -> 10 digit classes
        self.fc = torch.nn.Linear(28 * 28, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Handle input.

        Args:
            x: Batch of input images

        Return:
            Output class as 10-dimensional tensor.
        """
        # Flatten every image to a vector, keeping the batch dimension
        x = torch.flatten(x, 1)
        x = self.fc(x)
        # Turn the 10 scores into class probabilities
        output = torch.nn.functional.softmax(x, dim=1)
        return output

    def load(self, path: str):
        """Load model parameters from file.

        Args:
            path: Path to file
        """
        # TODO
        pass

    def save(self, path: str):
        """Save model parameters to file.

        Args:
            path: Path to file
        """
        # TODO
        pass
```
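To see what forward() does to a batch, the following sketch rebuilds the same layer stack in a stand-alone class (SingleLayerNet is a hypothetical name, used only so the example runs without the kitcar_ml package) and feeds it a random batch shaped like MNIST images. The batch size of 4 is arbitrary.

```python
import torch
import torch.nn
import torch.nn.functional


class SingleLayerNet(torch.nn.Module):
    """Hypothetical stand-alone copy of OnboardingNet's layers, for illustration only."""

    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(28 * 28, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.flatten(x, 1)  # (N, 1, 28, 28) -> (N, 784), batch dimension is kept
        x = self.fc(x)           # (N, 784) -> (N, 10) raw scores
        return torch.nn.functional.softmax(x, dim=1)  # each row becomes a probability distribution


torch.manual_seed(0)
net = SingleLayerNet()
batch = torch.rand(4, 1, 28, 28)  # 4 grayscale 28x28 images with values in [0, 1)
probs = net(batch)
print(probs.shape)               # torch.Size([4, 10])
print(probs.sum(dim=1))          # every row sums to 1 up to floating-point error
predicted = probs.argmax(dim=1)  # most likely digit for each image
```

The predicted digit is simply the index of the largest of the 10 outputs.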
3.2. Testing
Before improving the model, we need to look at how to run it on real data.
For this, take a closer look at kitcar_ml/onboarding/script.py.
The test() function is already defined and works out of the box.
It expects the path to your label file, then loads the model and runs it.
You can therefore simply execute:
python3 kitcar_ml/onboarding/script.py --test --test_label_file PATH/TO/YOUR/LABELS.yaml
You should receive an accuracy of around 10%.
Why? The model is not trained at all, so its predictions are essentially random guesses. Since the dataset contains the ten digits 0-9, a random guess is right about one time in ten.
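The sketch below illustrates how accuracy (and, optionally, an average loss) can be computed from a model that outputs softmax probabilities like OnboardingNet. It is an assumption about one reasonable approach, not necessarily how test() in script.py is written. The UniformModel and the random data are hypothetical stand-ins that mimic an untrained model assigning the same probability to every digit.

```python
import torch
import torch.nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()  # gradients are not needed for evaluation
def evaluate(model: torch.nn.Module, loader: DataLoader) -> tuple:
    """Return (accuracy, average loss) for a model that outputs class probabilities."""
    model.eval()
    # NLLLoss expects log-probabilities; summing lets us divide by the sample count at the end.
    criterion = torch.nn.NLLLoss(reduction="sum")
    correct, total_loss, count = 0, 0.0, 0
    for images, targets in loader:
        probs = model(images)
        log_probs = torch.log(torch.clamp(probs, min=1e-12))  # clamp avoids log(0)
        total_loss += criterion(log_probs, targets).item()
        correct += (probs.argmax(dim=1) == targets).sum().item()
        count += targets.size(0)
    return correct / count, total_loss / count


class UniformModel(torch.nn.Module):
    """Hypothetical 'untrained' model: assigns probability 0.1 to every digit."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.full((x.size(0), 10), 0.1)


# Hypothetical data: 100 random images whose labels cycle through 0-9 (ten of each digit).
images = torch.rand(100, 1, 28, 28)
labels = torch.arange(100) % 10
loader = DataLoader(TensorDataset(images, labels), batch_size=25)
accuracy, avg_loss = evaluate(UniformModel(), loader)
print(accuracy, avg_loss)  # accuracy 0.1, loss close to ln(10)
```

With all outputs tied, argmax picks digit 0 for every image, which is correct for exactly one tenth of the labels: the same 10% you see with the untrained network.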
3.3. Training
Now we can run the model and know how to load a dataset. The missing piece is training the model so that it actually predicts useful numbers.
Your Task
Complete the train() function.
Tip
All of the parts you will need are already described in the comments within the function. There are many good tutorials online that cover the basics of PyTorch, for example https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#define-a-loss-function-and-optimizer.
Tip
You can use test() as a guideline for accessing the data and targets.
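As a rough orientation, a typical PyTorch training loop looks like the sketch below. The function name, the SGD optimizer, the learning rate and the random stand-in data are all assumptions for illustration; your train() in script.py may be structured differently. The loop also moves data to the GPU when one is available. Note that torch.nn.CrossEntropyLoss applies log-softmax internally and expects raw scores (logits); OnboardingNet already applies softmax in forward(), so the sketch uses a bare Linear layer instead.

```python
import torch
import torch.nn
from torch.utils.data import DataLoader, TensorDataset


def train(model: torch.nn.Module, loader: DataLoader, epochs: int = 5, lr: float = 0.01) -> list:
    """Train ``model`` and return the average loss of each epoch."""
    # Use the GPU if one is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    criterion = torch.nn.CrossEntropyLoss()  # expects logits, not probabilities
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    model.train()
    epoch_losses = []
    for _ in range(epochs):
        running_loss, batches = 0.0, 0
        for images, targets in loader:
            images, targets = images.to(device), targets.to(device)
            optimizer.zero_grad()                     # clear gradients from the previous step
            loss = criterion(model(images), targets)  # forward pass and loss
            loss.backward()                           # backpropagation
            optimizer.step()                          # update the weights
            running_loss += loss.item()
            batches += 1
        epoch_losses.append(running_loss / batches)
    return epoch_losses


torch.manual_seed(0)
# Hypothetical stand-in for the real dataset: 64 random images with random labels.
images = torch.randn(64, 1, 28, 28)
targets = torch.randint(0, 10, (64,))
loader = DataLoader(TensorDataset(images, targets), batch_size=16, shuffle=True)
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(28 * 28, 10))
losses = train(model, loader)
print(losses)  # the average loss should shrink from epoch to epoch
```

The optimizer is created after model.to(device) so that it refers to the parameters on the device actually used for training.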
Similar to testing, you can run the training using the command:
python3 kitcar_ml/onboarding/script.py --train --train_label_file PATH/TO/YOUR/LABELS.yaml
You should see the model improve, although it will still be far from perfect.
3.4. Improvements
script.py and model.py are still very basic, and there are many improvements that can be made.
For the following tasks you will have to consult PyTorch's documentation or other online material.
The goal is for you to get comfortable working with PyTorch, datasets and neural networks in general.
Further Tasks
- Improve your model to reach more than 95% accuracy. First look at the layers PyTorch already provides. If you get stuck, there are many blog posts about MNIST online.
- Save the model after training and load it before testing by completing save() and load() in model.py.
- Calculate an average loss in your test() function (optional).
- Use your GPU for training (optional).
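For the first task, one common direction is to add convolutional layers. The following is a hypothetical example architecture (SmallConvNet, with layer sizes chosen for illustration only); it is not a tested solution, and no particular accuracy is claimed for it. Unlike OnboardingNet it returns logits, which pair directly with CrossEntropyLoss. Taking argmax over logits gives the same prediction as taking argmax over their softmax.

```python
import torch
import torch.nn


class SmallConvNet(torch.nn.Module):
    """Hypothetical small convolutional network for 1x28x28 MNIST images."""

    def __init__(self):
        super().__init__()
        self.features = torch.nn.Sequential(
            torch.nn.Conv2d(1, 32, kernel_size=3, padding=1),   # -> (N, 32, 28, 28)
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2),                              # -> (N, 32, 14, 14)
            torch.nn.Conv2d(32, 64, kernel_size=3, padding=1),  # -> (N, 64, 14, 14)
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2),                              # -> (N, 64, 7, 7)
        )
        self.classifier = torch.nn.Sequential(
            torch.nn.Flatten(),                                 # -> (N, 64 * 7 * 7)
            torch.nn.Linear(64 * 7 * 7, 128),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.25),
            torch.nn.Linear(128, 10),                           # -> (N, 10) logits
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Raw logits; apply softmax only if you need probabilities.
        return self.classifier(self.features(x))


net = SmallConvNet()
logits = net(torch.rand(8, 1, 28, 28))
print(logits.shape)  # torch.Size([8, 10])
```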
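For saving and loading, a common PyTorch pattern is to store only the model's state dict with torch.save and restore it with load_state_dict. The sketch below shows one way save() and load() could look; SaveLoadNet is a hypothetical stand-in for OnboardingNet, and the temporary file path is for illustration only.

```python
import os
import tempfile

import torch
import torch.nn


class SaveLoadNet(torch.nn.Module):
    """Hypothetical stand-in for OnboardingNet showing one way to implement save() and load()."""

    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(28 * 28, 10)

    def save(self, path: str):
        # Store only the learned parameters (the state dict), not the whole Python object.
        torch.save(self.state_dict(), path)

    def load(self, path: str):
        # map_location="cpu" allows loading a file saved on a GPU machine without a GPU.
        self.load_state_dict(torch.load(path, map_location="cpu"))


trained = SaveLoadNet()
path = os.path.join(tempfile.mkdtemp(), "onboarding_net.pt")
trained.save(path)

restored = SaveLoadNet()  # starts with different random weights
restored.load(path)
same = torch.equal(trained.fc.weight, restored.fc.weight)
print(same)  # True
```

Saving the state dict rather than the whole object keeps the file independent of where the class is defined; the class itself must still be importable when loading.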