mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[LoRA] training fix the position of param casting when loading them (#8460)
fix the position of param casting when loading them
This commit is contained in:
@@ -1289,8 +1289,8 @@ def main(args):
|
||||
models = [unet_]
|
||||
if args.train_text_encoder:
|
||||
models.extend([text_encoder_one_, text_encoder_two_])
|
||||
# only upcast trainable parameters (LoRA) into fp32
|
||||
cast_training_params(models)
|
||||
# only upcast trainable parameters (LoRA) into fp32
|
||||
cast_training_params(models)
|
||||
|
||||
accelerator.register_save_state_pre_hook(save_model_hook)
|
||||
accelerator.register_load_state_pre_hook(load_model_hook)
|
||||
|
||||
@@ -1363,8 +1363,8 @@ def main(args):
|
||||
models = [unet_]
|
||||
if args.train_text_encoder:
|
||||
models.extend([text_encoder_one_, text_encoder_two_])
|
||||
# only upcast trainable parameters (LoRA) into fp32
|
||||
cast_training_params(models)
|
||||
# only upcast trainable parameters (LoRA) into fp32
|
||||
cast_training_params(models)
|
||||
|
||||
accelerator.register_save_state_pre_hook(save_model_hook)
|
||||
accelerator.register_load_state_pre_hook(load_model_hook)
|
||||
|
||||
Reference in New Issue
Block a user