mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
LCM Distill Scripts Fix Bug when Initializing Target U-Net (#6848)
* Initialize target_unet from unet rather than teacher_unet so that we correctly add time_embedding.cond_proj if necessary. * Use UNet2DConditionModel.from_config to initialize target_unet from unet's config. --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -945,7 +945,7 @@ def main(args):
|
||||
|
||||
# 8. Create target student U-Net. This will be updated via EMA updates (polyak averaging).
|
||||
# Initialize from (online) unet
|
||||
target_unet = UNet2DConditionModel(**teacher_unet.config)
|
||||
target_unet = UNet2DConditionModel.from_config(unet.config)
|
||||
target_unet.load_state_dict(unet.state_dict())
|
||||
target_unet.train()
|
||||
target_unet.requires_grad_(False)
|
||||
|
||||
@@ -1004,7 +1004,7 @@ def main(args):
|
||||
|
||||
# 8. Create target student U-Net. This will be updated via EMA updates (polyak averaging).
|
||||
# Initialize from (online) unet
|
||||
target_unet = UNet2DConditionModel(**teacher_unet.config)
|
||||
target_unet = UNet2DConditionModel.from_config(unet.config)
|
||||
target_unet.load_state_dict(unet.state_dict())
|
||||
target_unet.train()
|
||||
target_unet.requires_grad_(False)
|
||||
|
||||
Reference in New Issue
Block a user