
Update train_text_to_image_lora.py (#2767)

* Update train_text_to_image_lora.py

* Update train_text_to_image_lora.py

* Update train_text_to_image_lora.py

* Update train_text_to_image_lora.py

* format
Haofan Wang committed 2023-03-23 21:28:47 +08:00 (committed by GitHub)
parent 0d7aac3e8d
commit dc5b4e2342


@@ -582,7 +582,7 @@ def main():
     else:
         optimizer_cls = torch.optim.AdamW
 
-    if args.peft:
+    if args.use_peft:
         # Optimizer creation
         params_to_optimize = (
             itertools.chain(unet.parameters(), text_encoder.parameters())
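The rename is the whole fix in this hunk and the two below: the script's argument parser registers `--use_peft`, so reading `args.peft` raises an AttributeError at runtime. A minimal sketch of the flag's behavior, assuming a plain `store_true` argument (the help text here is illustrative, not copied from the script):

    import argparse

    parser = argparse.ArgumentParser()
    # store_true flags default to False when omitted on the command line.
    parser.add_argument("--use_peft", action="store_true", help="train LoRA adapters via the PEFT library")

    args = parser.parse_args(["--use_peft"])
    print(args.use_peft)  # True
    # print(args.peft)    # AttributeError: 'Namespace' object has no attribute 'peft'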
@@ -724,7 +724,7 @@ def main():
     )
 
     # Prepare everything with our `accelerator`.
-    if args.peft:
+    if args.use_peft:
         if args.train_text_encoder:
             unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
                 unet, text_encoder, optimizer, train_dataloader, lr_scheduler
@@ -842,7 +842,7 @@ def main():
                 # Backpropagate
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
-                    if args.peft:
+                    if args.use_peft:
                         params_to_clip = (
                             itertools.chain(unet.parameters(), text_encoder.parameters())
                             if args.train_text_encoder
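Same rename at the gradient-clipping site. As a side note, the `itertools.chain` in `params_to_clip` simply concatenates the two parameter iterators so a single clipping call covers both models; a toy sketch with stand-in modules (the `Linear` layers are placeholders for the real UNet and text encoder):

    import itertools

    import torch

    unet = torch.nn.Linear(4, 4)          # placeholder for the real UNet
    text_encoder = torch.nn.Linear(4, 4)  # placeholder for the real text encoder
    train_text_encoder = True

    params_to_clip = (
        itertools.chain(unet.parameters(), text_encoder.parameters())
        if train_text_encoder
        else unet.parameters()
    )
    print(sum(1 for _ in params_to_clip))  # 4 tensors: weight + bias from each module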
@@ -922,18 +922,22 @@ def main():
     if accelerator.is_main_process:
         if args.use_peft:
             lora_config = {}
-            state_dict = get_peft_model_state_dict(unet, state_dict=accelerator.get_state_dict(unet))
-            lora_config["peft_config"] = unet.get_peft_config_as_dict(inference=True)
+            unwarpped_unet = accelerator.unwrap_model(unet)
+            state_dict = get_peft_model_state_dict(unwarpped_unet, state_dict=accelerator.get_state_dict(unet))
+            lora_config["peft_config"] = unwarpped_unet.get_peft_config_as_dict(inference=True)
             if args.train_text_encoder:
+                unwarpped_text_encoder = accelerator.unwrap_model(text_encoder)
                 text_encoder_state_dict = get_peft_model_state_dict(
-                    text_encoder, state_dict=accelerator.get_state_dict(text_encoder)
+                    unwarpped_text_encoder, state_dict=accelerator.get_state_dict(text_encoder)
                 )
                 text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()}
                 state_dict.update(text_encoder_state_dict)
-                lora_config["text_encoder_peft_config"] = text_encoder.get_peft_config_as_dict(inference=True)
+                lora_config["text_encoder_peft_config"] = unwarpped_text_encoder.get_peft_config_as_dict(
+                    inference=True
+                )
 
-            accelerator.save(state_dict, os.path.join(args.output_dir, f"{args.instance_prompt}_lora.pt"))
-            with open(os.path.join(args.output_dir, f"{args.instance_prompt}_lora_config.json"), "w") as f:
+            accelerator.save(state_dict, os.path.join(args.output_dir, f"{global_step}_lora.pt"))
+            with open(os.path.join(args.output_dir, f"{global_step}_lora_config.json"), "w") as f:
                 json.dump(lora_config, f)
         else:
             unet = unet.to(torch.float32)
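Two independent fixes land in this hunk. First, the `unwrap_model` calls: under distributed training, `accelerator.prepare` may wrap each model (for example in `DistributedDataParallel`), and PEFT methods such as `get_peft_config_as_dict` exist only on the underlying module, so they must be called on the unwrapped model. A single-process sketch of the unwrapping (the `Linear` layer stands in for the PEFT-wrapped UNet):

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()
    model = torch.nn.Linear(4, 4)  # stand-in for the PEFT-wrapped UNet
    prepared = accelerator.prepare(model)

    # unwrap_model strips any wrappers added by prepare() so attributes
    # defined on the original module are reachable again.
    unwrapped = accelerator.unwrap_model(prepared)
    print(type(prepared).__name__, "->", type(unwrapped).__name__)

Second, the checkpoint file names are re-keyed on `global_step` instead of `args.instance_prompt`; see the sketch after the loader changes below.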
@@ -957,12 +961,12 @@ def main():
 
     if args.use_peft:
 
-        def load_and_set_lora_ckpt(pipe, ckpt_dir, instance_prompt, device, dtype):
-            with open(f"{ckpt_dir}{instance_prompt}_lora_config.json", "r") as f:
+        def load_and_set_lora_ckpt(pipe, ckpt_dir, global_step, device, dtype):
+            with open(os.path.join(args.output_dir, f"{global_step}_lora_config.json"), "r") as f:
                 lora_config = json.load(f)
             print(lora_config)
 
-            checkpoint = f"{ckpt_dir}{instance_prompt}_lora.pt"
+            checkpoint = os.path.join(args.output_dir, f"{global_step}_lora.pt")
             lora_checkpoint_sd = torch.load(checkpoint)
             unet_lora_ds = {k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k}
             text_encoder_lora_ds = {
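For context on the loader: the save path above prefixes every text-encoder key with `text_encoder_`, and `load_and_set_lora_ckpt` partitions the merged checkpoint back on that prefix (the dict comprehension continues past this hunk). A self-contained sketch with placeholder values standing in for weight tensors:

    # Placeholder checkpoint; real values are LoRA weight tensors.
    lora_checkpoint_sd = {
        "down_blocks.0.lora_A.weight": 1,
        "text_encoder_layers.0.lora_A.weight": 2,
    }

    unet_lora_ds = {k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k}
    text_encoder_lora_ds = {
        k.replace("text_encoder_", ""): v for k, v in lora_checkpoint_sd.items() if "text_encoder_" in k
    }

    print(unet_lora_ds)          # {'down_blocks.0.lora_A.weight': 1}
    print(text_encoder_lora_ds)  # {'layers.0.lora_A.weight': 2}

The key names and the `replace` call are assumptions about the surrounding code, shown only to illustrate the prefix convention.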
@@ -985,9 +989,7 @@ def main():
             pipe.to(device)
             return pipe
 
-        pipeline = load_and_set_lora_ckpt(
-            pipeline, args.output_dir, args.instance_prompt, accelerator.device, weight_dtype
-        )
+        pipeline = load_and_set_lora_ckpt(pipeline, args.output_dir, global_step, accelerator.device, weight_dtype)
 
     else:
         pipeline = pipeline.to(accelerator.device)
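The call site now passes `global_step`, matching the new file-name scheme. Keying the checkpoint on the step count avoids `args.instance_prompt`, which this text-to-image script does not define (it is a DreamBooth-script argument), and which as a free-form prompt string could contain characters that are awkward in file names. The resulting layout, sketched with assumed values:

    import os

    output_dir = "sd-model-finetuned-lora"  # assumed value of args.output_dir
    global_step = 15000                     # assumed final training step

    # The two files written at the end of training, keyed by step count:
    print(os.path.join(output_dir, f"{global_step}_lora.pt"))           # LoRA weights
    print(os.path.join(output_dir, f"{global_step}_lora_config.json"))  # PEFT config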
@@ -995,7 +997,10 @@ def main():
         pipeline.unet.load_attn_procs(args.output_dir)
 
     # run inference
-    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+    if args.seed is not None:
+        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+    else:
+        generator = None
     images = []
     for _ in range(args.num_validation_images):
         images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])
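The final hunk guards validation seeding: `torch.Generator.manual_seed` requires an integer, so the old unconditional call raised a TypeError whenever `--seed` was left at its default of `None`. A minimal sketch of the fixed behavior:

    import torch

    def make_generator(seed, device="cpu"):
        # manual_seed(None) raises TypeError, so only seed when one is given;
        # returning None lets the pipeline fall back to the global RNG.
        if seed is not None:
            return torch.Generator(device=device).manual_seed(seed)
        return None

    print(make_generator(42))    # a seeded torch.Generator
    print(make_generator(None))  # None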