import abc
import itertools
import os
import pickle
from abc import ABC
from dataclasses import dataclass
from typing import Iterator, List, Optional, Tuple

import torch
from kitcar_utils.basics.init_options import InitOptions
from torch import Tensor, nn
from torch.nn import L1Loss, MSELoss
from torch.optim import RMSprop

from simulation.utils.machine_learning.models import helper

from .cycle_gan_stats import CycleGANStats
@dataclass
class CycleGANNetworks:
    """Container class for all networks used within the CycleGAN.

    The CycleGAN generally requires images from two domains a and b. It aims to translate
    images from one domain to the other.

    Iterating over an instance yields every network that is not ``None`` in attribute
    order (``g_a_to_b``, ``g_b_to_a``, ``d_a``, ``d_b``).
    """

    g_a_to_b: nn.Module
    """Generator that transforms images from domain a to domain b."""
    g_b_to_a: nn.Module
    """Generator that transforms images from domain b to domain a."""
    d_a: Optional[nn.Module] = None
    """Discriminator that decides for images if they are real or fake in domain a."""
    d_b: Optional[nn.Module] = None
    """Discriminator that decides for images if they are real or fake in domain b."""

    def _named_networks(self) -> Iterator[Tuple[str, nn.Module]]:
        """Yield ``(attribute name, network)`` for every network that is not None."""
        for name, net in vars(self).items():
            if net is not None:
                yield name, net

    @staticmethod
    def _unwrap(net: nn.Module) -> nn.Module:
        """Return the wrapped module of an ``nn.DataParallel``, otherwise ``net`` itself."""
        return net.module if isinstance(net, nn.DataParallel) else net

    def save(self, prefix_path: str) -> None:
        """Save the state dict of every network to the disk.

        Each network is written to ``prefix_path + "<name>.pth"``, where ``<name>`` is
        the attribute name (e.g. ``g_a_to_b``). Networks that are ``None`` are skipped.
        ``nn.DataParallel`` wrappers are removed first, so the stored keys have no
        ``module.`` prefix and match what ``load`` feeds into the unwrapped network.

        Args:
            prefix_path (str): the path which gets extended by the model name
        """
        for name, net in self._named_networks():
            save_path = prefix_path + f"{name}.pth"
            torch.save(self._unwrap(net).state_dict(), save_path)

    def load(self, prefix_path: str, device: torch.device) -> None:
        """Load all the networks from the disk.

        Reads ``prefix_path + "<name>.pth"`` for every network that is not ``None``.

        Args:
            prefix_path (str): the path which is extended by the model name
            device (torch.device): The device on which the networks are loaded

        Raises:
            FileNotFoundError: if the weights file of any network does not exist
        """
        for name, net in self._named_networks():
            load_path = prefix_path + f"{name}.pth"
            if not os.path.isfile(load_path):
                raise FileNotFoundError(f"No model weights file found at {load_path}")
            # Weights are always stored for the bare module (see save).
            net = self._unwrap(net)
            # if you are using PyTorch newer than 0.4 (e.g., built from
            # GitHub source), you can remove str() on device
            state_dict = torch.load(load_path, map_location=str(device))
            print(f"Loaded: {load_path}")
            # Drop the per-module version metadata stored with the tensors.
            # NOTE(review): kept from the legacy loader, presumably so old
            # checkpoints are treated as version-less by load_state_dict — confirm.
            if hasattr(state_dict, "_metadata"):
                del state_dict._metadata
            # patch InstanceNorm checkpoints prior to 0.4; iterate over a copy of
            # the keys because the patch pops entries from state_dict
            for key in list(state_dict):
                CycleGANNetworks.__patch_instance_norm_state_dict(
                    state_dict, net, key.split(".")
                )
            net.load_state_dict(state_dict)

    @staticmethod
    def __patch_instance_norm_state_dict(
        state_dict: dict, module: nn.Module, keys: List[str], i: int = 0
    ) -> None:
        """Fix InstanceNorm checkpoints incompatibility (prior to 0.4).

        Walks ``module`` along the dotted key path ``keys`` and, when the path ends
        at an InstanceNorm layer, removes entries the layer does not hold: the
        running statistics when the layer's attribute is None, and
        ``num_batches_tracked`` always.

        Args:
            state_dict (dict): a dict containing parameters from the saved model
                files; modified in place
            module (nn.Module): the network loaded from a file
            keys (List[str]): the dotted state dict key split into its parts
            i (int): current index in network structure
        """
        key = keys[i]
        if i + 1 < len(keys):
            # Not at the end yet: descend into the submodule named by this key.
            CycleGANNetworks.__patch_instance_norm_state_dict(
                state_dict, getattr(module, key), keys, i + 1
            )
            return
        # At the end, pointing to a parameter/buffer.
        if not module.__class__.__name__.startswith("InstanceNorm"):
            return
        if key in ("running_mean", "running_var"):
            if getattr(module, key) is None:
                state_dict.pop(".".join(keys))
        elif key == "num_batches_tracked":
            state_dict.pop(".".join(keys))

    def print(self, verbose: bool) -> None:
        """Print the total number of parameters in the network and (if verbose) network
        architecture.

        Args:
            verbose (bool): print the network architecture
        """
        # Inside the method body ``print`` still resolves to the builtin.
        print("---------- Networks initialized -------------")
        for name, net in self._named_networks():
            num_params = sum(param.numel() for param in net.parameters())
            if verbose:
                print(net)
            print(
                "[Network %s] Total number of parameters : %.3f M"
                % (name, num_params / 1e6)
            )
        print("-----------------------------------------------")

    def __iter__(self):
        return (net for _, net in self._named_networks())
class BaseModel(ABC, InitOptions):
    """Shared functionality of the CycleGAN models.

    Stores the four networks in a ``CycleGANNetworks`` container and, when training,
    creates the loss functions and optimizers. Provides the forward pass, the test
    pass, learning rate scheduling and eval mode switching. Subclasses implement
    ``do_iteration``.
    """

    def __init__(
        self,
        netg_a_to_b,
        netg_b_to_a,
        netd_a,
        netd_b,
        is_train,
        lambda_cycle,
        lambda_idt_a,
        lambda_idt_b,
        is_l1,
        optimizer_type,
        lr_policy,
        beta1: float = 0.5,
        lr: float = 0.0002,
        cycle_noise_stddev: float = 0,
    ):
        """
        Args:
            netg_a_to_b: generator that translates images from domain a to domain b
            netg_b_to_a: generator that translates images from domain b to domain a
            netd_a: discriminator of domain a (its parameters are needed when training)
            netd_b: discriminator of domain b (its parameters are needed when training)
            is_train: create loss functions and optimizers; also enables cycle noise
            lambda_cycle: stored for subclasses (presumably the cycle loss weight)
            lambda_idt_a: stored for subclasses (presumably identity loss weight, a)
            lambda_idt_b: stored for subclasses (presumably identity loss weight, b)
            is_l1: use L1Loss (True) or MSELoss (False) as cycle and identity criteria
            optimizer_type: ``"rms_prop"`` selects RMSprop, every other value Adam
            lr_policy: learning rate policy; ``"plateau"`` makes
                ``update_learning_rate`` pass ``self.metric`` to the schedulers
            beta1: first Adam beta (not used by RMSprop)
            lr: initial learning rate of both optimizers
            cycle_noise_stddev: standard deviation of the gaussian noise added before
                reconstructing the images; forced to 0 when not training
        """
        self.is_train = is_train
        self.lambda_cycle = lambda_cycle
        self.lambda_idt_a = lambda_idt_a
        self.lambda_idt_b = lambda_idt_b
        self.is_l1 = is_l1
        self.metric = 0  # used for learning rate policy 'plateau'
        self.lr_policy = lr_policy
        self.cycle_noise_stddev = cycle_noise_stddev if is_train else 0
        self.networks = CycleGANNetworks(netg_a_to_b, netg_b_to_a, netd_a, netd_b)
        if self.is_train:
            # define loss functions
            self.criterionCycle = L1Loss() if self.is_l1 else MSELoss()
            self.criterionIdt = L1Loss() if self.is_l1 else MSELoss()
            self.optimizer_g, self.optimizer_d = self._create_optimizers(
                optimizer_type, lr, beta1
            )

    def _create_optimizers(
        self, optimizer_type: str, lr: float, beta1: float
    ) -> Tuple[torch.optim.Optimizer, torch.optim.Optimizer]:
        """Create one optimizer for both generators and one for both discriminators.

        ``"rms_prop"`` selects RMSprop; every other value falls back to Adam.

        Returns:
            Tuple ``(optimizer_g, optimizer_d)``.
        """

        def make_optimizer(params):
            if optimizer_type == "rms_prop":
                return RMSprop(params, lr=lr)
            return torch.optim.Adam(params, lr=lr, betas=(beta1, 0.999))

        nets = self.networks
        optimizer_g = make_optimizer(
            itertools.chain(nets.g_a_to_b.parameters(), nets.g_b_to_a.parameters())
        )
        optimizer_d = make_optimizer(
            itertools.chain(nets.d_a.parameters(), nets.d_b.parameters())
        )
        return optimizer_g, optimizer_d

    def forward(self, real_a, real_b) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Run the forward pass; used by ``test`` (and presumably by subclasses).

        Args:
            real_a: batch of images from domain a
            real_b: batch of images from domain b

        Returns:
            ``(fake_a, fake_b, rec_a, rec_b)``: the translated images and the
            reconstructions obtained by translating them back.
        """
        fake_b = self.networks.g_a_to_b(real_a)  # G_A(A)
        fake_a = self.networks.g_b_to_a(real_b)  # G_B(B)
        # Calculate cycle. Add gaussian if self.cycle_noise_stddev is not 0
        # See: https://discuss.pytorch.org/t/writing-a-simple-gaussian-noise-layer-in-pytorch/4694 # noqa: E501
        # There are two individual noise terms because fake_A and fake_B may
        # have different dimensions
        # (At end of dataset were one of them is not a full batch for example)
        if self.cycle_noise_stddev == 0:
            noise_a = 0
            noise_b = 0
        else:
            # empty_like keeps the device and dtype of the generated images; a plain
            # torch.zeros(size) lives on the CPU and breaks the addition on a GPU.
            # The noise is a constant, so it does not need to require gradients.
            noise_a = torch.empty_like(fake_a).normal_(0, self.cycle_noise_stddev)
            noise_b = torch.empty_like(fake_b).normal_(0, self.cycle_noise_stddev)
        rec_a = self.networks.g_b_to_a(fake_b + noise_b)
        rec_b = self.networks.g_a_to_b(fake_a + noise_a)
        return fake_a, fake_b, rec_a, rec_b

    def test(self, batch_a, batch_b) -> CycleGANStats:
        """Forward function used in test time.

        This function wraps <forward> in no_grad() so no intermediate steps are
        saved for backpropagation.

        Returns:
            ``CycleGANStats`` built from both input batches, the translated images
            and their reconstructions.
        """
        with torch.no_grad():
            fake_a, fake_b, rec_a, rec_b = self.forward(batch_a, batch_b)
            return CycleGANStats(batch_a, batch_b, fake_a, fake_b, rec_a, rec_b)

    def create_schedulers(
        self,
        lr_policy: str = "linear",
        lr_decay_iters: int = 50,
        lr_step_factor: float = 0.1,
        n_epochs: int = 100,
    ):
        """Create one learning rate scheduler per optimizer.

        Stores them in ``self.schedulers`` (discriminator scheduler first, then the
        generator scheduler). Must be called before ``update_learning_rate``.

        Args:
            lr_policy: learning rate policy. [linear | step | plateau | cosine]
            lr_decay_iters: multiply by a gamma every lr_decay_iters iterations
            lr_step_factor: multiply lr with this factor every epoch
            n_epochs: number of epochs with the initial learning rate
        """
        self.schedulers = [
            helper.get_scheduler(
                optimizer, lr_policy, lr_decay_iters, n_epochs, lr_step_factor
            )
            for optimizer in [self.optimizer_d, self.optimizer_g]
        ]

    def eval(self) -> None:
        """Make models eval mode during test time."""
        for net in self.networks:
            net.eval()

    def update_learning_rate(self) -> None:
        """Step every scheduler once and print the generator learning rate change.

        With lr_policy ``"plateau"`` each scheduler is stepped with ``self.metric``.
        Requires ``create_schedulers`` to have been called.
        """
        old_lr = self.optimizer_g.param_groups[0]["lr"]
        for scheduler in self.schedulers:
            if self.lr_policy == "plateau":
                scheduler.step(self.metric)
            else:
                scheduler.step()
        lr = self.optimizer_g.param_groups[0]["lr"]
        print(f"learning rate {old_lr:.7f} -> {lr:.7f}")

    @abc.abstractmethod
    def do_iteration(
        self, batch_a: Tuple[torch.Tensor, str], batch_b: Tuple[torch.Tensor, str]
    ):
        """Process one pair of batches; must be implemented by subclasses."""
        raise NotImplementedError("Abstract method!")

    def pre_training(self):
        """Optional hook for subclasses (presumably run before training); no-op."""
        pass