mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Update train_text_to_image_lora.py (#2767)
* Update train_text_to_image_lora.py * Update train_text_to_image_lora.py * Update train_text_to_image_lora.py * Update train_text_to_image_lora.py * format
This commit is contained in:
@@ -582,7 +582,7 @@ def main():
|
||||
else:
|
||||
optimizer_cls = torch.optim.AdamW
|
||||
|
||||
if args.peft:
|
||||
if args.use_peft:
|
||||
# Optimizer creation
|
||||
params_to_optimize = (
|
||||
itertools.chain(unet.parameters(), text_encoder.parameters())
|
||||
@@ -724,7 +724,7 @@ def main():
|
||||
)
|
||||
|
||||
# Prepare everything with our `accelerator`.
|
||||
if args.peft:
|
||||
if args.use_peft:
|
||||
if args.train_text_encoder:
|
||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
|
||||
@@ -842,7 +842,7 @@ def main():
|
||||
# Backpropagate
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
if args.peft:
|
||||
if args.use_peft:
|
||||
params_to_clip = (
|
||||
itertools.chain(unet.parameters(), text_encoder.parameters())
|
||||
if args.train_text_encoder
|
||||
@@ -922,18 +922,22 @@ def main():
|
||||
if accelerator.is_main_process:
|
||||
if args.use_peft:
|
||||
lora_config = {}
|
||||
state_dict = get_peft_model_state_dict(unet, state_dict=accelerator.get_state_dict(unet))
|
||||
lora_config["peft_config"] = unet.get_peft_config_as_dict(inference=True)
|
||||
unwarpped_unet = accelerator.unwrap_model(unet)
|
||||
state_dict = get_peft_model_state_dict(unwarpped_unet, state_dict=accelerator.get_state_dict(unet))
|
||||
lora_config["peft_config"] = unwarpped_unet.get_peft_config_as_dict(inference=True)
|
||||
if args.train_text_encoder:
|
||||
unwarpped_text_encoder = accelerator.unwrap_model(text_encoder)
|
||||
text_encoder_state_dict = get_peft_model_state_dict(
|
||||
text_encoder, state_dict=accelerator.get_state_dict(text_encoder)
|
||||
unwarpped_text_encoder, state_dict=accelerator.get_state_dict(text_encoder)
|
||||
)
|
||||
text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()}
|
||||
state_dict.update(text_encoder_state_dict)
|
||||
lora_config["text_encoder_peft_config"] = text_encoder.get_peft_config_as_dict(inference=True)
|
||||
lora_config["text_encoder_peft_config"] = unwarpped_text_encoder.get_peft_config_as_dict(
|
||||
inference=True
|
||||
)
|
||||
|
||||
accelerator.save(state_dict, os.path.join(args.output_dir, f"{args.instance_prompt}_lora.pt"))
|
||||
with open(os.path.join(args.output_dir, f"{args.instance_prompt}_lora_config.json"), "w") as f:
|
||||
accelerator.save(state_dict, os.path.join(args.output_dir, f"{global_step}_lora.pt"))
|
||||
with open(os.path.join(args.output_dir, f"{global_step}_lora_config.json"), "w") as f:
|
||||
json.dump(lora_config, f)
|
||||
else:
|
||||
unet = unet.to(torch.float32)
|
||||
@@ -957,12 +961,12 @@ def main():
|
||||
|
||||
if args.use_peft:
|
||||
|
||||
def load_and_set_lora_ckpt(pipe, ckpt_dir, instance_prompt, device, dtype):
|
||||
with open(f"{ckpt_dir}{instance_prompt}_lora_config.json", "r") as f:
|
||||
def load_and_set_lora_ckpt(pipe, ckpt_dir, global_step, device, dtype):
|
||||
with open(os.path.join(args.output_dir, f"{global_step}_lora_config.json"), "r") as f:
|
||||
lora_config = json.load(f)
|
||||
print(lora_config)
|
||||
|
||||
checkpoint = f"{ckpt_dir}{instance_prompt}_lora.pt"
|
||||
checkpoint = os.path.join(args.output_dir, f"{global_step}_lora.pt")
|
||||
lora_checkpoint_sd = torch.load(checkpoint)
|
||||
unet_lora_ds = {k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k}
|
||||
text_encoder_lora_ds = {
|
||||
@@ -985,9 +989,7 @@ def main():
|
||||
pipe.to(device)
|
||||
return pipe
|
||||
|
||||
pipeline = load_and_set_lora_ckpt(
|
||||
pipeline, args.output_dir, args.instance_prompt, accelerator.device, weight_dtype
|
||||
)
|
||||
pipeline = load_and_set_lora_ckpt(pipeline, args.output_dir, global_step, accelerator.device, weight_dtype)
|
||||
|
||||
else:
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
@@ -995,7 +997,10 @@ def main():
|
||||
pipeline.unet.load_attn_procs(args.output_dir)
|
||||
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
|
||||
if args.seed is not None:
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
|
||||
else:
|
||||
generator = None
|
||||
images = []
|
||||
for _ in range(args.num_validation_images):
|
||||
images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])
|
||||
|
||||
Reference in New Issue
Block a user