# EchoFlow — lifm/UNet-S-16f8/config.yaml
# Training configuration (uploaded by HReynaud, commit 514f603, "training configs", 3.52 kB).
# NOTE(review): the lines above were Hugging Face file-viewer page residue, not YAML;
# they have been converted to comments so the file parses.
---
# Shared values referenced elsewhere via ${globals.*} interpolations (OmegaConf-style).
globals:
  target_fps: original      # keep each clip's native frame rate (no resampling)
  target_nframes: 64        # frames sampled per clip
  outputs:                  # items each dataset must yield per sample
    - image
    - view
  resolution: 112           # pixel-space frame resolution
  latent_res: 14            # spatial resolution of VAE latents (112 / 8 with an f8 VAE — see vae: avae-16f8)
  latent_channels: 16       # channel count of VAE latents
# Denoiser network: a 2D segmentation-conditioned UNet; `args` are passed to the
# target class constructor.
denoiser:
  target: echosyn.common.models.SegUnet2DModel
  args:
    # NOTE(review): sample_size is 28 while globals.latent_res is 14 — looks like
    # 2x the latent resolution; confirm against SegUnet2DModel's expectations.
    sample_size: 28
    in_channels: 17         # 16 latent channels + 1 extra — presumably the segmentation map; verify
    out_channels: 16        # matches globals.latent_channels
    center_input_sample: false
    time_embedding_type: positional
    freq_shift: 0
    flip_sin_to_cos: true
    down_block_types:       # attention on the three highest-resolution stages
      - AttnDownBlock2D
      - AttnDownBlock2D
      - AttnDownBlock2D
      - DownBlock2D
    up_block_types:         # mirror of down_block_types (decoder order is reversed)
      - UpBlock2D
      - AttnUpBlock2D
      - AttnUpBlock2D
      - AttnUpBlock2D
    block_out_channels:
      - 96
      - 192
      - 288
      - 384
    layers_per_block: 2
    mid_block_scale_factor: 1
    downsample_padding: 1
    downsample_type: resnet
    upsample_type: resnet
    dropout: 0.0
    act_fn: silu
    attention_head_dim: 8
    norm_num_groups: 32
    attn_norm_num_groups: null
    norm_eps: 1.0e-05
    resnet_time_scale_shift: default
    class_embed_type: timestep
    num_class_embeds: null
# Optimizer: standard AdamW with defaults-like betas/eps and mild weight decay.
optimizer:
  target: torch.optim.AdamW
  args:
    lr: 5.0e-05
    betas:
      - 0.9
      - 0.999
    weight_decay: 0.01
    eps: 1.0e-08
# LR schedule: linear warmup then decay, driven by optimizer steps (not epochs).
scheduler:
  target: echosyn.common.schedulers.StepBasedLearningRateScheduleWithWarmup
  args:
    warmup_steps: 5000
    ref_steps: ${max_train_steps}   # interpolates the top-level max_train_steps (1000000)
    eta_min: 1.0e-06                # floor learning rate
    decay_rate: 2
# Frozen VAE used to encode/decode latents; loaded from a local pretrained path.
vae:
  target: diffusers.AutoencoderKL
  pretrained: vae/avae-16f8   # 16-channel, 8x-downsampling autoencoder (matches latent_channels/latent_res)
# Training datasets: four pre-encoded latent corpora, all served through the same
# LatentSeg dataset class. Shared params are pulled from globals via interpolation;
# only root, view_label, and segmentation_root differ per entry.
datasets:
  - name: LatentSeg           # EchoNet-Dynamic, apical 4-chamber
    active: true
    params:
      root: avae-16f8/dynamic
      outputs: ${globals.outputs}
      target_fps: ${globals.target_fps}
      view_label: A4C
      target_nframes: ${globals.target_nframes}
      latent_channels: ${globals.latent_channels}
      segmentation_root: segmentations/dynamic
      target_resolution: ${globals.latent_res}
  - name: LatentSeg           # EchoNet-Pediatric, apical 4-chamber
    active: true
    params:
      root: avae-16f8/ped_a4c
      outputs: ${globals.outputs}
      target_fps: ${globals.target_fps}
      view_label: A4C
      target_nframes: ${globals.target_nframes}
      latent_channels: ${globals.latent_channels}
      segmentation_root: segmentations/ped_a4c
      target_resolution: ${globals.latent_res}
  - name: LatentSeg           # EchoNet-Pediatric, parasternal short-axis
    active: true
    params:
      root: avae-16f8/ped_psax
      outputs: ${globals.outputs}
      target_fps: ${globals.target_fps}
      view_label: PSAX
      target_nframes: ${globals.target_nframes}
      latent_channels: ${globals.latent_channels}
      segmentation_root: segmentations/ped_psax
      target_resolution: ${globals.latent_res}
  - name: LatentSeg           # EchoNet-LVH, parasternal long-axis
    active: true
    params:
      root: avae-16f8/lvh
      outputs: ${globals.outputs}
      target_fps: ${globals.target_fps}
      view_label: PLAX
      target_nframes: ${globals.target_nframes}
      latent_channels: ${globals.latent_channels}
      segmentation_root: no_seg   # sentinel value — this corpus has no segmentations; confirm handling in LatentSeg
      target_resolution: ${globals.latent_res}
# DataLoader settings; `args` forwarded to torch.utils.data.DataLoader.
dataloader:
  target: torch.utils.data.DataLoader
  args:
    shuffle: true
    batch_size: 128
    num_workers: 16
    pin_memory: true
    drop_last: true           # keeps every batch full-size (required for stable step counts)
    persistent_workers: true
# --- Top-level training loop parameters ---
max_train_steps: 1000000          # also consumed by scheduler.args.ref_steps via interpolation
gradient_accumulation_steps: 1
mixed_precision: bf16
use_ema: true                     # keep an exponential-moving-average copy of the weights
noise_offset: 0.1
max_grad_norm: 1.0                # gradient clipping by norm
max_grad_value: -1                # -1 presumably disables value clipping — confirm in trainer
pad_latents: false
sample_latents: true
output_dir: experiments/${wandb_args.name}   # e.g. experiments/UNet-S-16f8
logging_dir: logs
report_to: wandb
# Weights & Biases run identification.
wandb_args:
  project: EchoFlow
  name: UNet-S-16f8               # also interpolated into output_dir
  group: UNet
# Checkpointing: save every 10k steps; keep only the listed milestone checkpoints.
checkpointing_steps: 10000
checkpoints_to_keep:
  - 50000
  - 100000
  - 200000
  - 500000
  - 1000000
resume_from_checkpoint: latest    # resume from the most recent checkpoint if one exists
# Periodic validation sampling during training.
validation:
  samples: 4                      # number of samples generated per validation round
  steps: 5000                     # run validation every 5000 training steps
  method: euler                   # ODE solver used for sampling
  timesteps: 25                   # sampling steps per generated clip
  seed: 42
# NOTE(review): presumably derived from max_train_steps and dataset size by the
# export tool; the trainer appears step-based (max_train_steps) — confirm this is used.
num_train_epochs: 45455