From 07eac4d65a8ec67e7ae971da4431f67095e9db8a Mon Sep 17 00:00:00 2001
From: dg845 <58458699+dg845@users.noreply.github.com>
Date: Mon, 27 Nov 2023 04:00:40 -0800
Subject: [PATCH] Fix LCM Stable Diffusion distillation bug related to parsing
 unet_time_cond_proj_dim (#5893)

* Fix bug related to parsing unet_time_cond_proj_dim.

* Fix analogous bug in the SD-XL LCM distillation script.
---
 .../train_lcm_distill_sd_wds.py              | 11 ++++++++++-
 .../train_lcm_distill_sdxl_wds.py            | 14 ++++++++++++--
 2 files changed, 22 insertions(+), 3 deletions(-)

diff --git a/examples/consistency_distillation/train_lcm_distill_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_sd_wds.py
index ec4bf432f0..4c4ad984fc 100644
--- a/examples/consistency_distillation/train_lcm_distill_sd_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_sd_wds.py
@@ -657,6 +657,15 @@ def parse_args():
         default=0.001,
         help="The huber loss parameter. Only used if `--loss_type=huber`.",
     )
+    parser.add_argument(
+        "--unet_time_cond_proj_dim",
+        type=int,
+        default=256,
+        help=(
+            "The dimension of the guidance scale embedding in the U-Net, which will be used if the teacher U-Net"
+            " does not have `time_cond_proj_dim` set."
+        ),
+    )
     # ----Exponential Moving Average (EMA)----
     parser.add_argument(
         "--ema_decay",
@@ -1138,7 +1147,7 @@ def main(args):
 
             # 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
             w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
-            w_embedding = guidance_scale_embedding(w, embedding_dim=args.unet_time_cond_proj_dim)
+            w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim)
             w = w.reshape(bsz, 1, 1, 1)
             # Move to U-Net device and dtype
             w = w.to(device=latents.device, dtype=latents.dtype)
diff --git a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
index 7d2b1e1032..920950d0f6 100644
--- a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
@@ -677,6 +677,15 @@ def parse_args():
         default=0.001,
         help="The huber loss parameter. Only used if `--loss_type=huber`.",
     )
+    parser.add_argument(
+        "--unet_time_cond_proj_dim",
+        type=int,
+        default=256,
+        help=(
+            "The dimension of the guidance scale embedding in the U-Net, which will be used if the teacher U-Net"
+            " does not have `time_cond_proj_dim` set."
+        ),
+    )
     # ----Exponential Moving Average (EMA)----
     parser.add_argument(
         "--ema_decay",
@@ -1233,6 +1242,7 @@ def main(args):
 
             # 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
             w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
+            w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim)
             w = w.reshape(bsz, 1, 1, 1)
             w = w.to(device=latents.device, dtype=latents.dtype)
 
@@ -1243,7 +1253,7 @@ def main(args):
             noise_pred = unet(
                 noisy_model_input,
                 start_timesteps,
-                timestep_cond=None,
+                timestep_cond=w_embedding,
                 encoder_hidden_states=prompt_embeds.float(),
                 added_cond_kwargs=encoded_text,
             ).sample
@@ -1308,7 +1318,7 @@ def main(args):
             target_noise_pred = target_unet(
                 x_prev.float(),
                 timesteps,
-                timestep_cond=None,
+                timestep_cond=w_embedding,
                 encoder_hidden_states=prompt_embeds.float(),
                 added_cond_kwargs=encoded_text,
             ).sample
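
Background on the fix: in diffusers' `UNet2DConditionModel`, `time_cond_proj_dim` sets the input width of the linear projection applied to `timestep_cond`, so the guidance-scale embedding must be built with exactly that dimension. The SD script embedded w using `args.unet_time_cond_proj_dim`, which can disagree with the dimension the student U-Net was actually built with (e.g. when the teacher U-Net already has `time_cond_proj_dim` set); the SD-XL script never computed or passed the embedding at all (`timestep_cond=None`), so the sampled guidance scale never conditioned the online or target U-Net. The sketch below illustrates the sinusoidal guidance-scale embedding these scripts rely on; it is a minimal re-implementation for illustration, and the scaling constants (the 1000x multiplier and the 10000 frequency base) are assumptions patterned on common diffusers timestep embeddings, not a verbatim copy of the scripts' helper.

    import torch

    def guidance_scale_embedding(w: torch.Tensor, embedding_dim: int = 256) -> torch.Tensor:
        # Sinusoidal embedding of the guidance scale w: (bsz,) -> (bsz, embedding_dim).
        # The 1000x scale and 10000 frequency base are assumptions for this sketch.
        assert w.ndim == 1
        w = w.float() * 1000.0
        half_dim = embedding_dim // 2
        freqs = torch.exp(
            -torch.log(torch.tensor(10000.0)) * torch.arange(half_dim) / (half_dim - 1)
        )
        emb = w[:, None] * freqs[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        if embedding_dim % 2 == 1:
            emb = torch.nn.functional.pad(emb, (0, 1))  # zero-pad odd embedding dims
        return emb

    # Usage mirroring the patched scripts: embed with the dimension the student
    # U-Net was actually configured with, then pass the result as timestep_cond.
    # w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim)

Reading the dimension from `unet.config.time_cond_proj_dim` rather than a CLI flag keeps the embedding width and the U-Net's projection layer in sync by construction, which is why the patch routes both scripts through the U-Net config and keeps `--unet_time_cond_proj_dim` only as a fallback for teachers without `time_cond_proj_dim` set.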