From 17cece072aec2007e7c3febb99455b66fd485af5 Mon Sep 17 00:00:00 2001 From: dg845 <58458699+dg845@users.noreply.github.com> Date: Wed, 10 Jan 2024 18:51:07 -0800 Subject: [PATCH] Fix bug in LCM Distillation Scripts when args.unet_time_cond_proj_dim is used (#6523) * Fix bug where unet's time_cond_proj_dim is not set correctly if using args.unet_time_cond_proj_dim. * make style --- .../train_lcm_distill_sd_wds.py | 10 ++++++---- .../train_lcm_distill_sdxl_wds.py | 10 ++++++---- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_sd_wds.py index 3a07b3cf34..1f375201c6 100644 --- a/examples/consistency_distillation/train_lcm_distill_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sd_wds.py @@ -921,10 +921,12 @@ def main(args): # 7. Create online student U-Net. This will be updated by the optimizer (e.g. via backpropagation.) # Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None - if teacher_unet.config.time_cond_proj_dim is None: - teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim - time_cond_proj_dim = teacher_unet.config.time_cond_proj_dim - unet = UNet2DConditionModel(**teacher_unet.config) + time_cond_proj_dim = ( + teacher_unet.config.time_cond_proj_dim + if teacher_unet.config.time_cond_proj_dim is not None + else args.unet_time_cond_proj_dim + ) + unet = UNet2DConditionModel.from_config(teacher_unet.config, time_cond_proj_dim=time_cond_proj_dim) # load teacher_unet weights into unet unet.load_state_dict(teacher_unet.state_dict(), strict=False) unet.train() diff --git a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py index 5d2442a4e4..e3ca1c8a22 100644 --- a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py @@ -980,10 +980,12 @@ def main(args): # 7. Create online student U-Net. This will be updated by the optimizer (e.g. via backpropagation.) # Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None - if teacher_unet.config.time_cond_proj_dim is None: - teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim - time_cond_proj_dim = teacher_unet.config.time_cond_proj_dim - unet = UNet2DConditionModel(**teacher_unet.config) + time_cond_proj_dim = ( + teacher_unet.config.time_cond_proj_dim + if teacher_unet.config.time_cond_proj_dim is not None + else args.unet_time_cond_proj_dim + ) + unet = UNet2DConditionModel.from_config(teacher_unet.config, time_cond_proj_dim=time_cond_proj_dim) # load teacher_unet weights into unet unet.load_state_dict(teacher_unet.state_dict(), strict=False) unet.train()