"""Source code for kitcar_ml.utils.models.normalize."""
import torch
from torch import nn
from torchvision import transforms
class Normalize(nn.Module):
    """Standardize every sample of a batch, channel by channel.

    Each element of ``inputs`` is shifted and scaled so that every slice
    along its first dimension has zero mean and unit (unbiased) standard
    deviation, with the statistics computed over dimensions 1 and 2.
    NOTE(review): this presumably expects ``(C, H, W)`` images — confirm
    against callers.
    """

    def forward(self, inputs: torch.Tensor, *args):
        """Normalize input tensor.

        Args:
            inputs: Iterable of floating-point tensors (e.g. a batch tensor
                iterated along dim 0), each with at least 3 dimensions.
            *args: Extra values passed through unchanged.

        Returns:
            A tuple ``(normalized, *args)`` where ``normalized`` is a list
            holding one normalized tensor per element of ``inputs``.

        Raises:
            ValueError: If any channel of a sample has zero standard
                deviation (e.g. a constant image). torchvision's
                ``transforms.Normalize`` raises the same error type here.
        """
        normalized = []
        for sample in inputs:
            # torch.std_mean returns (std, mean) -- in that order.
            # keepdim=True keeps the reduced dims so the statistics broadcast
            # back over the sample without manual reshaping.
            std, mean = torch.std_mean(sample, dim=(1, 2), keepdim=True)
            if (std == 0).any():
                raise ValueError(
                    "std evaluated to zero, leading to division by zero."
                )
            normalized.append((sample - mean) / std)
        return normalized, *args