"""Base model for the CycleGAN (simulation.utils.machine_learning.cycle_gan.models.base_model)."""

import abc
import itertools
import os
import pickle
from abc import ABC
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
from kitcar_utils.basics.init_options import InitOptions
from torch import Tensor, nn
from torch.nn import L1Loss, MSELoss
from torch.optim import RMSprop

from simulation.utils.machine_learning.models import helper

from .cycle_gan_stats import CycleGANStats


@dataclass
class CycleGANNetworks:
    """Container class for all networks used within the CycleGAN.

    The CycleGAN generally requires images from two domains a and b.
    It aims to translate images from one domain to the other.
    """

    g_a_to_b: nn.Module
    """Generator that transforms images from domain a to domain b."""

    g_b_to_a: nn.Module
    """Generator that transforms images from domain b to domain a."""

    # The discriminators are optional: they default to None and are skipped by
    # save/load/print/__iter__, so a pure-inference model needs only the generators.
    d_a: Optional[nn.Module] = None
    """Discriminator that decides for images if they are real or fake in domain a."""

    d_b: Optional[nn.Module] = None
    """Discriminator that decides for images if they are real or fake in domain b."""
[docs] def save(self, prefix_path: str) -> None: """Save all the networks to the disk. Args: prefix_path (str): the path which gets extended by the model name """ for name, net in self.__dict__.items(): if net is None: continue net = pickle.loads(pickle.dumps(net)) save_path = prefix_path + f"{name}.pth" torch.save(net.state_dict(), save_path)
    def load(self, prefix_path: str, device: torch.device):
        """Load all non-None networks' weights from the disk.

        Expects one ``<prefix_path><name>.pth`` file per non-None network
        attribute, as written by :meth:`save`.

        Args:
            prefix_path (str): the path which is extended by the model name
            device (torch.device): The device on which the networks are loaded

        Raises:
            FileNotFoundError: if a weights file for any network is missing.
        """
        for name, net in self.__dict__.items():
            # Networks that were never created (e.g. discriminators at
            # inference time) are simply skipped.
            if net is None:
                continue
            load_path = prefix_path + f"{name}.pth"
            if not os.path.isfile(load_path):
                raise FileNotFoundError(f"No model weights file found at {load_path}")
            # Unwrap DataParallel: the checkpoint keys match the inner module.
            if isinstance(net, torch.nn.DataParallel):
                net = net.module
            # if you are using PyTorch newer than 0.4 (e.g., built from
            # GitHub source), you can remove str() on device
            state_dict = torch.load(load_path, map_location=str(device))
            print(f"Loaded: {load_path}")
            if hasattr(state_dict, "_metadata"):
                del state_dict._metadata

            # patch InstanceNorm checkpoints prior to 0.4
            for key in list(
                state_dict.keys()
            ):  # need to copy keys here because we mutate in loop
                CycleGANNetworks.__patch_instance_norm_state_dict(
                    state_dict, net, key.split(".")
                )
            net.load_state_dict(state_dict)
    @staticmethod
    def __patch_instance_norm_state_dict(
        state_dict: dict, module: nn.Module, keys: List[str], i: int = 0
    ) -> None:
        """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)

        Recursively walks ``module`` along the dotted key path and removes
        stale InstanceNorm buffer entries from ``state_dict`` in place.

        Args:
            state_dict (dict): a dict containing parameters from the saved model files
            keys (List[str]): the dotted key of one entry, split into parts
            module (nn.Module): the network loaded from a file
            i (int): current index in network structure
        """
        key = keys[i]
        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
            # Drop running stats the current module no longer tracks.
            if module.__class__.__name__.startswith("InstanceNorm") and (
                key == "running_mean" or key == "running_var"
            ):
                if getattr(module, key) is None:
                    state_dict.pop(".".join(keys))
            # num_batches_tracked never exists on these InstanceNorm layers.
            if module.__class__.__name__.startswith("InstanceNorm") and (
                key == "num_batches_tracked"
            ):
                state_dict.pop(".".join(keys))
        else:
            # Descend one level into the submodule named by this key part.
            CycleGANNetworks.__patch_instance_norm_state_dict(
                state_dict, getattr(module, key), keys, i + 1
            )
[docs] def print(self, verbose: bool) -> None: """Print the total number of parameters in the network and (if verbose) network architecture. Args: verbose (bool): print the network architecture """ print("---------- Networks initialized -------------") for name, net in self.__dict__.items(): if net is None: continue num_params = 0 for param in net.parameters(): num_params += param.numel() if verbose: print(net) print( "[Network %s] Total number of parameters : %.3f M" % (name, num_params / 1e6) ) print("-----------------------------------------------")
def __iter__(self): return (net for net in self.__dict__.values() if net is not None)
class BaseModel(ABC, InitOptions):
    """Shared base class for CycleGAN models.

    Wraps the four networks in a :class:`CycleGANNetworks` container and, when
    training, creates the cycle/identity criteria and the two optimizers.
    """

    def __init__(
        self,
        netg_a_to_b,
        netg_b_to_a,
        netd_a,
        netd_b,
        is_train,
        lambda_cycle,
        lambda_idt_a,
        lambda_idt_b,
        is_l1,
        optimizer_type,
        lr_policy,
        beta1: float = 0.5,
        lr: float = 0.0002,
        cycle_noise_stddev: float = 0,
    ):
        """Store options and set up losses/optimizers when training.

        Args:
            netg_a_to_b: generator network translating a -> b
            netg_b_to_a: generator network translating b -> a
            netd_a: discriminator network for domain a
            netd_b: discriminator network for domain b
            is_train: whether the model is used for training
            lambda_cycle: weight of the cycle consistency loss
            lambda_idt_a: weight of the identity loss in domain a
            lambda_idt_b: weight of the identity loss in domain b
            is_l1: use L1 loss instead of MSE for cycle/identity criteria
            optimizer_type: "rms_prop" selects RMSprop; anything else uses Adam
            lr_policy: learning rate policy name (e.g. "plateau")
            beta1: first momentum coefficient for Adam
            lr: initial learning rate
            cycle_noise_stddev: stddev of gaussian noise added to the cycle
                inputs; forced to 0 outside of training
        """
        self.is_train = is_train
        self.lambda_cycle = lambda_cycle
        self.lambda_idt_a = lambda_idt_a
        self.lambda_idt_b = lambda_idt_b
        self.is_l1 = is_l1
        self.metric = 0  # used for learning rate policy 'plateau'
        self.lr_policy = lr_policy
        self.cycle_noise_stddev = cycle_noise_stddev if is_train else 0

        self.networks = CycleGANNetworks(netg_a_to_b, netg_b_to_a, netd_a, netd_b)

        if self.is_train:
            # Both criteria use the same loss type; instantiate one per purpose.
            loss_class = L1Loss if self.is_l1 else MSELoss
            self.criterionCycle = loss_class()
            self.criterionIdt = loss_class()

            # Each chain is consumed exactly once by the optimizer it is passed to.
            generator_parameters = itertools.chain(
                self.networks.g_a_to_b.parameters(),
                self.networks.g_b_to_a.parameters(),
            )
            discriminator_parameters = itertools.chain(
                self.networks.d_a.parameters(), self.networks.d_b.parameters()
            )

            if optimizer_type == "rms_prop":
                self.optimizer_g = RMSprop(generator_parameters, lr=lr)
                self.optimizer_d = RMSprop(discriminator_parameters, lr=lr)
            else:
                self.optimizer_g = torch.optim.Adam(
                    generator_parameters, lr=lr, betas=(beta1, 0.999)
                )
                self.optimizer_d = torch.optim.Adam(
                    discriminator_parameters, lr=lr, betas=(beta1, 0.999)
                )
[docs] def forward(self, real_a, real_b) -> Tuple[Tensor, Tensor, Tensor, Tensor]: """Run forward pass; called by both functions <optimize_parameters> and <test>.""" fake_b = self.networks.g_a_to_b(real_a) # G_A(A) fake_a = self.networks.g_b_to_a(real_b) # G_B(B) # Calculate cycle. Add gaussian if self.cycle_noise_stddev is not 0 # See: https://discuss.pytorch.org/t/writing-a-simple-gaussian-noise-layer-in-pytorch/4694 # noqa: E501 # There are two individual noise terms because fake_A and fake_B may # have different dimensions # (At end of dataset were one of them is not a full batch for example) if self.cycle_noise_stddev == 0: noise_a = 0 noise_b = 0 else: noise_a = ( torch.zeros(fake_a.size()) .normal_(0, self.cycle_noise_stddev) .requires_grad_() ) noise_b = ( torch.zeros(fake_b.size()) .normal_(0, self.cycle_noise_stddev) .requires_grad_() ) rec_a = self.networks.g_b_to_a(fake_b + noise_b) rec_b = self.networks.g_a_to_b(fake_a + noise_a) return fake_a, fake_b, rec_a, rec_b
[docs] def test(self, batch_a, batch_b) -> CycleGANStats: """Forward function used in test time. This function wraps <forward> function in no_grad() so we don't save intermediate steps for backpropagation It also calls <compute_visuals> to produce additional visualization results """ with torch.no_grad(): fake_a, fake_b, rec_a, rec_b = self.forward(batch_a, batch_b) return CycleGANStats(batch_a, batch_b, fake_a, fake_b, rec_a, rec_b)
[docs] def create_schedulers( self, lr_policy: str = "linear", lr_decay_iters: int = 50, lr_step_factor: float = 0.1, n_epochs: int = 100, ): """Create schedulers. Args: lr_policy: learning rate policy. [linear | step | plateau | cosine] lr_decay_iters: multiply by a gamma every lr_decay_iters iterations lr_step_factor: multiply lr with this factor every epoch n_epochs: number of epochs with the initial learning rate """ self.schedulers = [ helper.get_scheduler( optimizer, lr_policy, lr_decay_iters, n_epochs, lr_step_factor ) for optimizer in [self.optimizer_d, self.optimizer_g] ]
[docs] def eval(self) -> None: """Make models eval mode during test time.""" for net in self.networks: net.eval()
[docs] def update_learning_rate(self) -> None: """Update learning rates for all the networks.""" old_lr = self.optimizer_g.param_groups[0]["lr"] for scheduler in self.schedulers: if self.lr_policy == "plateau": scheduler.step(self.metric) else: scheduler.step() lr = self.optimizer_g.param_groups[0]["lr"] print(f"learning rate {old_lr:.7f} -> {lr:.7f}")
    @abc.abstractmethod
    def do_iteration(
        self, batch_a: Tuple[torch.Tensor, str], batch_b: Tuple[torch.Tensor, str]
    ):
        """Run one training iteration on a batch from each domain.

        Must be implemented by subclasses.

        Args:
            batch_a: tuple of (image tensor, str) from domain a — the str is
                presumably the sample's file path; verify against the caller.
            batch_b: tuple of (image tensor, str) from domain b.
        """
        raise NotImplementedError("Abstract method!")
    def pre_training(self):
        """Optional hook run before training; no-op by default.

        Subclasses may override this to perform setup work.
        """
        pass