"""Source code for ``simulation.utils.machine_learning.cycle_gan.configs.train_options``."""

from typing import List

from .base_options import BaseOptions


class TrainOptions(BaseOptions):
    """Default option values used when training.

    Extends :class:`BaseOptions` with dataset locations, display settings,
    logging/checkpoint frequencies and optimizer / learning-rate schedule
    parameters. Every option is a class attribute followed by its attribute
    docstring.
    """

    # NOTE(review): the list defaults below are class-level objects; whether
    # BaseOptions copies them per instance is not visible here, so avoid
    # mutating them in place.
    dataset_a: List[str] = [
        "./../../../../data/real_images/beg_2019",
    ]
    """Path to images of domain A (real images).

    Can be a list of folders.
    """
    dataset_b: List[str] = ["./../../../../data/simulated_images/random_roads"]
    """Path to images of domain B (simulated images).

    Can be a list of folders.
    """

    display_id: int = 1
    """Window id of the web display."""
    display_port: int = 8097
    """Visdom port of the web display."""

    is_train: bool = True
    """Enable or disable training mode."""
    num_threads: int = 8
    """Number of threads for loading data."""

    save_freq: int = 100
    """Frequency of saving the current models."""
    print_freq: int = 5
    """Frequency of showing training results on console."""

    beta1: float = 0.5
    """Momentum term of adam."""
    batch_size: int = 3
    """Input batch size."""

    lr: float = 0.0005
    """Initial learning rate for adam."""
    lr_decay_iters: int = 1
    """Multiply by a gamma every lr_decay_iters iterations."""
    lr_policy: str = "step"
    """Learning rate policy. [linear | step | plateau | cosine]"""
    lr_step_factor: float = 0.1
    """Multiplication factor at every step in the step scheduler."""

    n_epochs: int = 0
    """Number of epochs with the initial learning rate."""
    n_epochs_decay: int = 10
    """Number of epochs to linearly decay learning rate to zero."""

    # NOTE(review): the attribute name suggests that ``True`` *disables*
    # flipping, while the description below reads as if it enables it.
    # Confirm against the code that consumes this option before relying on
    # either interpretation.
    no_flip: bool = False
    """Flip 50% of all training images vertically."""
    continue_train: bool = False
    """Load checkpoints or start from scratch."""


class WassersteinCycleGANTrainOptions(TrainOptions):
    """Training options for the Wasserstein cycle GAN variant.

    Adds the critic iteration counts and weight-clipping bounds on top of
    the options inherited from :class:`TrainOptions`.
    """

    wgan_initial_n_critic: int = 1
    """Number of iterations of the critic before starting training loop."""
    wgan_clip_upper: float = 0.001
    """Upper bound for weight clipping."""
    wgan_clip_lower: float = -0.001
    """Lower bound for weight clipping."""
    wgan_n_critic: int = 5
    """Number of iterations of the critic per generator iteration."""
    is_wgan: bool = True
    """Decide whether to use wasserstein cycle gan or standard cycle gan."""


class CycleGANTrainOptions(TrainOptions):
    """Training options for the standard cycle GAN.

    Defines no options of its own; everything is inherited unchanged from
    :class:`TrainOptions`.
    """