Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
[Wuerstchen] fix fp16 training and correct lora args (#6245)
fix fp16 training

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
@@ -527,9 +527,17 @@ def main():
 
     # lora attn processor
     prior_lora_config = LoraConfig(
-        r=args.rank, target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"]
+        r=args.rank,
+        lora_alpha=args.rank,
+        target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
     )
+    # Add adapter and make sure the trainable params are in float32.
     prior.add_adapter(prior_lora_config)
+    if args.mixed_precision == "fp16":
+        for param in prior.parameters():
+            # only upcast trainable parameters (LoRA) into fp32
+            if param.requires_grad:
+                param.data = param.to(torch.float32)
 
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
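For context, here is a minimal, self-contained sketch of the two fixes in this commit. TinyAttention is a hypothetical stand-in for the Wuerstchen prior, and peft's get_peft_model is used in place of diffusers' add_adapter; the LoraConfig arguments and the fp32 upcast loop mirror the diff above.

import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model

class TinyAttention(nn.Module):
    # hypothetical module with attention-style projections, so the
    # target_modules names below have something to match
    def __init__(self, dim=32):
        super().__init__()
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)

    def forward(self, x):
        return self.to_q(x) + self.to_k(x) + self.to_v(x)

rank = 4
lora_config = LoraConfig(
    r=rank,
    lora_alpha=rank,  # scaling = lora_alpha / r = 1.0, as in the commit
    target_modules=["to_q", "to_k", "to_v"],
)

base = TinyAttention().to(torch.float16)   # pretend fp16 mixed-precision run
model = get_peft_model(base, lora_config)  # stand-in for prior.add_adapter(...)

# Only the trainable LoRA parameters are upcast to fp32; the frozen fp16
# base weights are left untouched, exactly like the loop in the diff.
for param in model.parameters():
    if param.requires_grad:
        param.data = param.to(torch.float32)

model.print_trainable_parameters()

Setting lora_alpha equal to r keeps the LoRA scaling factor at 1, and upcasting only the trainable adapter weights to fp32 avoids fp16 precision problems in their updates while keeping the memory benefit of an fp16 base model.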