Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
Fix LCM Stable Diffusion distillation bug related to parsing unet_time_cond_proj_dim (#5893)
* Fix bug related to parsing unet_time_cond_proj_dim.
* Fix analogous bug in the SD-XL LCM distillation script.
In the Stable Diffusion LCM distillation script, parse_args() gains the new argument:

@@ -657,6 +657,15 @@ def parse_args():
         default=0.001,
         help="The huber loss parameter. Only used if `--loss_type=huber`.",
     )
+    parser.add_argument(
+        "--unet_time_cond_proj_dim",
+        type=int,
+        default=256,
+        help=(
+            "The dimension of the guidance scale embedding in the U-Net, which will be used if the teacher U-Net"
+            " does not have `time_cond_proj_dim` set."
+        ),
+    )
     # ----Exponential Moving Average (EMA)----
     parser.add_argument(
         "--ema_decay",
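For context on how the new flag interacts with the fix below: in these distillation scripts the student U-Net is built from the teacher's config, and --unet_time_cond_proj_dim is only meant as a fallback when the teacher has no time_cond_proj_dim of its own. A rough sketch of that construction (names such as teacher_unet, and the exact call shown here, are assumptions for illustration, not part of this diff):

    from diffusers import UNet2DConditionModel

    # Hypothetical sketch: prefer the dimension baked into the teacher config,
    # fall back to the CLI flag only when the teacher does not define one.
    time_cond_proj_dim = (
        teacher_unet.config.time_cond_proj_dim
        if teacher_unet.config.time_cond_proj_dim is not None
        else args.unet_time_cond_proj_dim
    )
    # The student U-Net therefore always ends up with a concrete
    # config.time_cond_proj_dim, which is the value the training loop should read.
    unet = UNet2DConditionModel.from_config(teacher_unet.config, time_cond_proj_dim=time_cond_proj_dim)

Because the student's config becomes the single source of truth after this step, sizing the guidance-scale embedding from args.unet_time_cond_proj_dim (as the script did before) can disagree with the projection layer actually created inside the U-Net, whereas reading unet.config.time_cond_proj_dim cannot.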
@@ -1138,7 +1147,7 @@ def main(args):
 
                 # 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
                 w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
-                w_embedding = guidance_scale_embedding(w, embedding_dim=args.unet_time_cond_proj_dim)
+                w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim)
                 w = w.reshape(bsz, 1, 1, 1)
                 # Move to U-Net device and dtype
                 w = w.to(device=latents.device, dtype=latents.dtype)
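The embedding dimension matters because guidance_scale_embedding produces a (batch, embedding_dim) sinusoidal embedding of w that the U-Net's conditioning projection must be able to consume. A minimal sketch of such a function (following the usual LCM-style sinusoidal formulation; not copied verbatim from the script):

    import torch

    def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
        # Sinusoidal embedding of the guidance scale w (LCM-style).
        # Output shape is (batch, embedding_dim), so embedding_dim must match
        # the U-Net's time_cond_proj_dim or its cond_proj Linear will reject it.
        w = w * 1000.0
        half_dim = embedding_dim // 2
        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
        emb = w.to(dtype)[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        if embedding_dim % 2 == 1:  # zero-pad odd dimensions
            emb = torch.nn.functional.pad(emb, (0, 1))
        return emb

With the fix, embedding_dim is always whatever dimension the student U-Net was actually built with, regardless of what --unet_time_cond_proj_dim was set to on the command line.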
The same argument and the analogous fixes are applied in the SD-XL LCM distillation script:

@@ -677,6 +677,15 @@ def parse_args():
         default=0.001,
         help="The huber loss parameter. Only used if `--loss_type=huber`.",
     )
+    parser.add_argument(
+        "--unet_time_cond_proj_dim",
+        type=int,
+        default=256,
+        help=(
+            "The dimension of the guidance scale embedding in the U-Net, which will be used if the teacher U-Net"
+            " does not have `time_cond_proj_dim` set."
+        ),
+    )
     # ----Exponential Moving Average (EMA)----
     parser.add_argument(
         "--ema_decay",
@@ -1233,6 +1242,7 @@ def main(args):
 
                 # 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
                 w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
+                w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim)
                 w = w.reshape(bsz, 1, 1, 1)
                 w = w.to(device=latents.device, dtype=latents.dtype)
 
@@ -1243,7 +1253,7 @@ def main(args):
                 noise_pred = unet(
                     noisy_model_input,
                     start_timesteps,
-                    timestep_cond=None,
+                    timestep_cond=w_embedding,
                     encoder_hidden_states=prompt_embeds.float(),
                     added_cond_kwargs=encoded_text,
                 ).sample
@@ -1308,7 +1318,7 @@ def main(args):
                 target_noise_pred = target_unet(
                     x_prev.float(),
                     timesteps,
-                    timestep_cond=None,
+                    timestep_cond=w_embedding,
                     encoder_hidden_states=prompt_embeds.float(),
                     added_cond_kwargs=encoded_text,
                 ).sample
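The timestep_cond changes matter because a U-Net built with time_cond_proj_dim only uses the guidance-scale embedding if it is actually passed in; with timestep_cond=None the embedding computed above was silently discarded. A simplified sketch of the mechanism, modeled on diffusers' TimestepEmbedding (the real class has more options; this is illustrative only):

    import torch
    import torch.nn as nn

    class TimestepEmbedding(nn.Module):
        # Simplified: the real diffusers class supports extra activations,
        # post-activation layers, and more configuration.
        def __init__(self, in_channels, time_embed_dim, cond_proj_dim=None):
            super().__init__()
            self.linear_1 = nn.Linear(in_channels, time_embed_dim)
            self.act = nn.SiLU()
            self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
            # Only created when the U-Net config sets time_cond_proj_dim.
            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) if cond_proj_dim else None

        def forward(self, sample, condition=None):
            if condition is not None and self.cond_proj is not None:
                # The guidance-scale embedding is projected and added to the
                # raw timestep embedding before the timestep MLP.
                sample = sample + self.cond_proj(condition)
            sample = self.linear_1(sample)
            sample = self.act(sample)
            return self.linear_2(sample)

So passing timestep_cond=w_embedding to both the online unet and the target_unet is what makes the student actually condition on w during distillation.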