# Provenance (from the hosting page this file was copied from):
#   author: 刘虹雨 · commit 8ed2f16 ("update") · size: 458 bytes
from mmcv import Registry
from DiT_VAE.diffusion.model.utils import set_grad_checkpoint
# mmcv registry that build_model() resolves config ``type`` names against.
# NOTE(review): the registry is named 'vae' although the module path suggests
# DiT models are built here too — confirm this name is intentional.
MODELS = Registry('vae')
def build_model(cfg, use_grad_checkpoint=False, use_fp32_attention=False, gc_step=1, **kwargs):
    """Build a model from the ``MODELS`` registry.

    Args:
        cfg: Either a config mapping passed straight to ``MODELS.build``, or a
            string, which is treated as the registered type name and wrapped
            as ``dict(type=cfg)``.
        use_grad_checkpoint: When true, ``set_grad_checkpoint`` is called on
            the freshly built model before it is returned.
        use_fp32_attention: Forwarded to ``set_grad_checkpoint``; ignored
            unless ``use_grad_checkpoint`` is true.
        gc_step: Forwarded to ``set_grad_checkpoint``; ignored unless
            ``use_grad_checkpoint`` is true.
        **kwargs: Passed to ``MODELS.build`` as ``default_args`` (in mmcv these
            fill in keys that ``cfg`` does not already set).

    Returns:
        The model instance produced by ``MODELS.build``.
    """
    model_cfg = dict(type=cfg) if isinstance(cfg, str) else cfg
    model = MODELS.build(model_cfg, default_args=kwargs)

    if not use_grad_checkpoint:
        return model

    # The helper's return value is discarded, so it presumably configures the
    # model in place — verify against set_grad_checkpoint's implementation.
    set_grad_checkpoint(model, use_fp32_attention=use_fp32_attention, gc_step=gc_step)
    return model