From 0bee4d336b925b6064eee156f5a155e3ca3b30ab Mon Sep 17 00:00:00 2001 From: dg845 <58458699+dg845@users.noreply.github.com> Date: Thu, 11 Apr 2024 10:52:12 -0700 Subject: [PATCH] LCM Distill Scripts Fix Bug when Initializing Target U-Net (#6848) * Initialize target_unet from unet rather than teacher_unet so that we correctly add time_embedding.cond_proj if necessary. * Use UNet2DConditionModel.from_config to initialize target_unet from unet's config. --------- Co-authored-by: Sayak Paul --- examples/consistency_distillation/train_lcm_distill_sd_wds.py | 2 +- examples/consistency_distillation/train_lcm_distill_sdxl_wds.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_sd_wds.py index d873cb8deb..5dcad9f6cc 100644 --- a/examples/consistency_distillation/train_lcm_distill_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sd_wds.py @@ -945,7 +945,7 @@ def main(args): # 8. Create target student U-Net. This will be updated via EMA updates (polyak averaging). # Initialize from (online) unet - target_unet = UNet2DConditionModel(**teacher_unet.config) + target_unet = UNet2DConditionModel.from_config(unet.config) target_unet.load_state_dict(unet.state_dict()) target_unet.train() target_unet.requires_grad_(False) diff --git a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py index 862777411c..a7deca72a8 100644 --- a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py @@ -1004,7 +1004,7 @@ def main(args): # 8. Create target student U-Net. This will be updated via EMA updates (polyak averaging). # Initialize from (online) unet - target_unet = UNet2DConditionModel(**teacher_unet.config) + target_unet = UNet2DConditionModel.from_config(unet.config) target_unet.load_state_dict(unet.state_dict()) target_unet.train() target_unet.requires_grad_(False)