Source code for kitcar_ml.utils.models.normalize

import torch
from torch import nn
from torchvision import transforms


class Normalize(nn.Module):
    """Standardize images per channel using each image's own statistics.

    Every image is shifted and scaled so that each of its channels has zero
    mean and unit (Bessel-corrected) standard deviation. The statistics are
    computed from that single image, not from the whole batch or a dataset.
    """

    def forward(self, inputs: torch.Tensor, *args) -> tuple:
        """Normalize every image in ``inputs`` with its own per-channel mean/std.

        Args:
            inputs: Anything iterable that yields image tensors one at a time.
                A batched tensor is split along its first dimension; a list
                yields its elements. Each image must be a floating-point
                tensor with at least 3 dimensions, since the statistics are
                reduced over dimensions 1 and 2. Presumably the layout is
                ``(C, H, W)``; confirm against callers.
            *args: Extra values (e.g. targets) passed through unchanged.

        Returns:
            A tuple ``(normalized, *args)``. ``normalized`` is a list holding
            one normalized tensor per input image, each with the same shape
            as its input. Inputs are not modified in place.

        Raises:
            ValueError: If any channel of an image has zero standard deviation
                (e.g. a constant-colored image), because dividing by it would
                silently produce inf/nan values.
        """
        normalized = []
        for image in inputs:
            # NOTE: torch.std_mean returns (std, mean) in that order.
            # keepdim=True keeps the reduced dims so the statistics broadcast
            # directly against the image (shape (C, 1, 1) for a (C, H, W) image).
            std, mean = torch.std_mean(image, dim=(1, 2), keepdim=True)
            if (std == 0).any():
                raise ValueError(
                    "Cannot normalize image: a channel has zero standard deviation."
                )
            normalized.append((image - mean) / std)
        return (normalized, *args)