mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Fix bug in LCM Distillation Scripts when args.unet_time_cond_proj_dim is used (#6523)
* Fix bug where unet's time_cond_proj_dim is not set correctly if using args.unet_time_cond_proj_dim. * make style
This commit is contained in:
@@ -921,10 +921,12 @@ def main(args):
|
||||
|
||||
# 7. Create online student U-Net. This will be updated by the optimizer (e.g. via backpropagation.)
|
||||
# Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None
|
||||
if teacher_unet.config.time_cond_proj_dim is None:
|
||||
teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim
|
||||
time_cond_proj_dim = teacher_unet.config.time_cond_proj_dim
|
||||
unet = UNet2DConditionModel(**teacher_unet.config)
|
||||
time_cond_proj_dim = (
|
||||
teacher_unet.config.time_cond_proj_dim
|
||||
if teacher_unet.config.time_cond_proj_dim is not None
|
||||
else args.unet_time_cond_proj_dim
|
||||
)
|
||||
unet = UNet2DConditionModel.from_config(teacher_unet.config, time_cond_proj_dim=time_cond_proj_dim)
|
||||
# load teacher_unet weights into unet
|
||||
unet.load_state_dict(teacher_unet.state_dict(), strict=False)
|
||||
unet.train()
|
||||
|
||||
@@ -980,10 +980,12 @@ def main(args):
|
||||
|
||||
# 7. Create online student U-Net. This will be updated by the optimizer (e.g. via backpropagation.)
|
||||
# Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None
|
||||
if teacher_unet.config.time_cond_proj_dim is None:
|
||||
teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim
|
||||
time_cond_proj_dim = teacher_unet.config.time_cond_proj_dim
|
||||
unet = UNet2DConditionModel(**teacher_unet.config)
|
||||
time_cond_proj_dim = (
|
||||
teacher_unet.config.time_cond_proj_dim
|
||||
if teacher_unet.config.time_cond_proj_dim is not None
|
||||
else args.unet_time_cond_proj_dim
|
||||
)
|
||||
unet = UNet2DConditionModel.from_config(teacher_unet.config, time_cond_proj_dim=time_cond_proj_dim)
|
||||
# load teacher_unet weights into unet
|
||||
unet.load_state_dict(teacher_unet.state_dict(), strict=False)
|
||||
unet.train()
|
||||
|
||||
Reference in New Issue
Block a user