
[Wuerstchen] fix fp16 training and correct lora args (#6245)

Fix fp16 training by upcasting only the trainable LoRA parameters to float32, and correct the LoRA args by passing lora_alpha to LoraConfig.

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Author: Kashif Rasul
Date: 2023-12-26 11:40:04 +01:00
Committed by: sayakpaul
Parent: 4c7e983bb5
Commit: 0bb9cf0216


@@ -527,9 +527,17 @@ def main():
     # lora attn processor
     prior_lora_config = LoraConfig(
-        r=args.rank, target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"]
+        r=args.rank,
+        lora_alpha=args.rank,
+        target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
     )
+    # Add adapter and make sure the trainable params are in float32.
     prior.add_adapter(prior_lora_config)
+    if args.mixed_precision == "fp16":
+        for param in prior.parameters():
+            # only upcast trainable parameters (LoRA) into fp32
+            if param.requires_grad:
+                param.data = param.to(torch.float32)
 
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
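
For context, below is a minimal, self-contained sketch of the same fp16 fix applied to a toy model. TinyModel, the rank value, and the module names are illustrative stand-ins, not part of the Wuerstchen training script. It shows why the upcast loop is needed: after adding a LoRA adapter to a base model loaded in fp16, only the trainable adapter parameters are moved to float32 so the optimizer never updates raw fp16 weights.

import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model

# Toy stand-in for the Wuerstchen prior; names are illustrative only.
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.to_q = nn.Linear(32, 32)
        self.to_k = nn.Linear(32, 32)

    def forward(self, x):
        return self.to_q(x) + self.to_k(x)

rank = 4
model = TinyModel().to(torch.float16)  # simulate loading the base model in fp16

# Same LoRA args as the commit: lora_alpha set alongside r.
lora_config = LoraConfig(r=rank, lora_alpha=rank, target_modules=["to_q", "to_k"])
model = get_peft_model(model, lora_config)

# Mirror the fix: upcast only the trainable (LoRA) parameters to fp32,
# leaving the frozen base weights in fp16.
for param in model.parameters():
    if param.requires_grad:
        param.data = param.to(torch.float32)

# Sanity check: trainable params now report torch.float32.
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, param.dtype)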