diff --git a/examples/consistency_distillation/train_lcm_distill_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_sd_wds.py
index ec4bf432f0..4c4ad984fc 100644
--- a/examples/consistency_distillation/train_lcm_distill_sd_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_sd_wds.py
@@ -657,6 +657,15 @@ def parse_args():
         default=0.001,
         help="The huber loss parameter. Only used if `--loss_type=huber`.",
     )
+    parser.add_argument(
+        "--unet_time_cond_proj_dim",
+        type=int,
+        default=256,
+        help=(
+            "The dimension of the guidance scale embedding in the U-Net, which will be used if the teacher U-Net"
+            " does not have `time_cond_proj_dim` set."
+        ),
+    )
     # ----Exponential Moving Average (EMA)----
     parser.add_argument(
         "--ema_decay",
@@ -1138,7 +1147,7 @@ def main(args):
 
                 # 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
                 w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
-                w_embedding = guidance_scale_embedding(w, embedding_dim=args.unet_time_cond_proj_dim)
+                w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim)
                 w = w.reshape(bsz, 1, 1, 1)
                 # Move to U-Net device and dtype
                 w = w.to(device=latents.device, dtype=latents.dtype)
diff --git a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
index 7d2b1e1032..920950d0f6 100644
--- a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
@@ -677,6 +677,15 @@ def parse_args():
         default=0.001,
         help="The huber loss parameter. Only used if `--loss_type=huber`.",
     )
+    parser.add_argument(
+        "--unet_time_cond_proj_dim",
+        type=int,
+        default=256,
+        help=(
+            "The dimension of the guidance scale embedding in the U-Net, which will be used if the teacher U-Net"
+            " does not have `time_cond_proj_dim` set."
+        ),
+    )
     # ----Exponential Moving Average (EMA)----
     parser.add_argument(
         "--ema_decay",
@@ -1233,6 +1242,7 @@ def main(args):
 
                 # 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
                 w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
+                w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim)
                 w = w.reshape(bsz, 1, 1, 1)
                 w = w.to(device=latents.device, dtype=latents.dtype)
 
@@ -1243,7 +1253,7 @@ def main(args):
                 noise_pred = unet(
                     noisy_model_input,
                     start_timesteps,
-                    timestep_cond=None,
+                    timestep_cond=w_embedding,
                     encoder_hidden_states=prompt_embeds.float(),
                     added_cond_kwargs=encoded_text,
                 ).sample
@@ -1308,7 +1318,7 @@ def main(args):
                 target_noise_pred = target_unet(
                     x_prev.float(),
                     timesteps,
-                    timestep_cond=None,
+                    timestep_cond=w_embedding,
                     encoder_hidden_states=prompt_embeds.float(),
                     added_cond_kwargs=encoded_text,
                 ).sample
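In both scripts, the guidance weight w sampled from U[w_min, w_max] is fed to the student U-Net through `timestep_cond`, so it must first be projected to an embedding whose width matches the U-Net's `time_cond_proj_dim`. Below is a minimal, hedged sketch of a sinusoidal guidance-scale embedding consistent with how `guidance_scale_embedding` is called in the hunks above; the helper actually defined in the training scripts may differ in its details.

    import torch
    import torch.nn.functional as F

    def guidance_scale_embedding(w: torch.Tensor, embedding_dim: int = 256, dtype=torch.float32) -> torch.Tensor:
        # Sketch only: sinusoidal embedding of the scaled guidance weight,
        # output shape (len(w), embedding_dim).
        assert w.ndim == 1
        w = w * 1000.0
        half_dim = embedding_dim // 2
        freqs = torch.exp(
            -torch.log(torch.tensor(10000.0)) * torch.arange(half_dim, dtype=dtype) / (half_dim - 1)
        )
        emb = w.to(dtype)[:, None] * freqs[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        if embedding_dim % 2 == 1:
            emb = F.pad(emb, (0, 1))  # zero-pad to the requested width for odd dims
        return emb

Per its help text, the new `--unet_time_cond_proj_dim` flag (default 256) is only a fallback for teachers whose config has no `time_cond_proj_dim`. A hypothetical sketch of how the student U-Net could choose the dimension, assuming the scripts' `teacher_unet` and parsed `args` are in scope:

    from diffusers import UNet2DConditionModel

    # Assumption: reuse the teacher's time_cond_proj_dim when present, otherwise
    # fall back to the new CLI flag so unet.config.time_cond_proj_dim is always set.
    time_cond_proj_dim = (
        teacher_unet.config.time_cond_proj_dim
        if teacher_unet.config.time_cond_proj_dim is not None
        else args.unet_time_cond_proj_dim
    )
    unet = UNet2DConditionModel.from_config(teacher_unet.config, time_cond_proj_dim=time_cond_proj_dim)

With the embedding in place, passing `timestep_cond=w_embedding` to both the online and target U-Nets (as the SDXL hunks now do) conditions the distilled model on the sampled guidance scale instead of leaving that projection unused.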