Source code for simulation.utils.machine_learning.cycle_gan.configs.base_options

from typing import List, Union

from torch import nn


class BaseOptions:
    """Default options shared by the CycleGAN configurations.

    Options are plain class attributes; subclasses override or extend them.
    Use :meth:`to_dict` to collect the effective values across the class
    hierarchy.
    """

    activation: nn.Module = nn.Tanh()
    """Choose which activation to use."""
    checkpoints_dir: str = "./checkpoints"
    """Models are saved here."""
    conv_layers_in_block: int = 3
    """Specify number of convolution layers per resnet block."""
    crop_size: int = 512
    """Then crop to this size."""
    dilations: List[int] = [
        1,
        2,
        4,
    ]
    """Dilation for individual conv layers in every resnet block."""
    epoch: Union[int, str] = "latest"
    """Which epoch to load? Set to latest to use latest cached model."""
    init_gain: float = 0.02
    """Scaling factor for normal, xavier and orthogonal."""
    init_type: str = "normal"
    """Network initialization [normal | xavier | kaiming | orthogonal]"""
    input_nc: int = 1
    """# of input image channels: 3 for RGB and 1 for grayscale"""
    lambda_idt_a: float = 0.5
    """Weight for loss identity of domain A."""
    lambda_idt_b: float = 0.5
    """Weight for loss identity of domain B."""
    lambda_cycle: float = 10
    """Weight for cycle loss."""
    load_size: int = 512
    """Scale images to this size."""
    mask: str = "resources/mask.png"
    """Path to a mask overlaid over all images."""
    n_layers_d: int = 4
    """Number of layers in the discriminator network."""
    name: str = "dr_drift"
    """Name of the experiment. It decides where to store samples and models."""
    ndf: int = 32
    """# of discriminator filters in the first conv layer"""
    netd: str = "basic"
    """Specify discriminator architecture. [basic | n_layers | no_patch].

    The basic model is a 70x70 PatchGAN. n_layers allows you to specify the
    layers in the discriminator.
    """
    netg: str = "resnet_9blocks"
    """Specify generator architecture [resnet_<ANY_INTEGER>blocks | unet_256 | unet_128]"""
    ngf: int = 32
    """# of gen filters in the last conv layer"""
    no_dropout: bool = True
    """No dropout for the generator."""
    norm: str = "instance"
    """Instance normalization or batch normalization [instance | batch | none]"""
    output_nc: int = 1
    """# of output image channels: 3 for RGB and 1 for grayscale."""
    preprocess: set = {"resize", "crop"}
    """Scaling and cropping of images at load time. [resize | crop | scale_width]"""
    verbose: bool = False
    """If specified, print more debugging information."""
    cycle_noise_stddev: float = 0
    """Standard deviation of noise added to the cycle input. Mean is 0."""
    pool_size: int = 75
    """The size of image buffer that stores previously generated images."""
    max_dataset_size: int = 15000
    """Maximum amount of images to load; -1 means infinity."""
    is_wgan: bool = False
    """Decide whether to use wasserstein cycle gan or standard cycle gan."""
    l1_or_l2_loss: str = "l1"
    """"l1" or "l2"; Decide whether to use l1 or l2 as cycle and identity loss functions."""
    use_sigmoid: bool = True
    """Use sigmoid activation at end of discriminator."""

    @classmethod
    def to_dict(cls) -> dict:
        """Collect the option attributes of ``cls`` and all of its base classes.

        The MRO is walked from ``object`` down to ``cls``, so a value defined
        in a subclass overrides the one inherited from its bases. Dunder names
        and method-like attributes (plain functions, classmethods,
        staticmethods, properties) are skipped, so the result contains only
        options. Callable option values such as ``activation`` (an
        ``nn.Module``) are kept.

        Note: the values are the class-level objects themselves, not copies.

        Returns:
            Mapping from option name to its effective value.
        """
        from types import FunctionType

        # Filter by descriptor type rather than ``callable``: option values
        # like ``nn.Tanh()`` are callable but must be included.
        method_types = (FunctionType, classmethod, staticmethod, property)

        options = {}
        for klass in reversed(cls.__mro__):
            for key, value in vars(klass).items():
                if key.startswith("__") or isinstance(value, method_types):
                    continue
                options[key] = value
        return options