Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
Make text-to-image SDXL LoRA Training Script torch.compile compatible (#6556)
make compile compatible
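The change set below makes the script survive torch.compile rather than enabling it: ModelOutput-style return values are swapped for plain tuples, and model-identity checks learn to peel the compile wrapper. Compilation itself would be switched on by the user, roughly like this (an illustrative sketch; this call is not part of the commit):

    import torch
    import torch.nn as nn

    # Stand-in for the SDXL UNet; in the script this would be the loaded
    # UNet2DConditionModel, compiled before accelerator.prepare().
    unet = nn.Sequential(nn.Linear(8, 8), nn.SiLU(), nn.Linear(8, 8))
    unet = torch.compile(unet)  # wraps the module in an OptimizedModule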
@@ -54,6 +54,7 @@ from diffusers.optimization import get_scheduler
 from diffusers.training_utils import cast_training_params, compute_snr
 from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
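The newly imported is_compiled_module helper reports whether a module has been wrapped by torch.compile. A minimal sketch of the idea, assuming torch >= 2.0 (not the library's exact source):

    import torch
    import torch.nn as nn

    def is_compiled_module_sketch(module: nn.Module) -> bool:
        # torch.compile() returns a torch._dynamo OptimizedModule that wraps the
        # original module and keeps it reachable via `_orig_mod`.
        return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)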
@@ -460,13 +461,12 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
             text_input_ids = text_input_ids_list[i]

         prompt_embeds = text_encoder(
-            text_input_ids.to(text_encoder.device),
-            output_hidden_states=True,
+            text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False
         )

         # We are only ALWAYS interested in the pooled output of the final text encoder
         pooled_prompt_embeds = prompt_embeds[0]
-        prompt_embeds = prompt_embeds.hidden_states[-2]
+        prompt_embeds = prompt_embeds[-1][-2]
         bs_embed, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
         prompt_embeds_list.append(prompt_embeds)
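With return_dict=False the text encoder returns a plain tuple instead of a transformers ModelOutput, which compiled graphs digest more reliably. Under the transformers CLIP output ordering, hidden_states becomes the last tuple element, so prompt_embeds[-1][-2] selects the same penultimate hidden state that .hidden_states[-2] did, and prompt_embeds[0] is still the pooled projection of the final text encoder. A sketch of the indexing, assuming the SDXL base checkpoint layout:

    from transformers import AutoTokenizer, CLIPTextModelWithProjection

    repo = "stabilityai/stable-diffusion-xl-base-1.0"
    tokenizer = AutoTokenizer.from_pretrained(repo, subfolder="tokenizer_2")
    encoder = CLIPTextModelWithProjection.from_pretrained(repo, subfolder="text_encoder_2")

    ids = tokenizer("a photo of a cat", return_tensors="pt").input_ids
    out = encoder(ids, output_hidden_states=True, return_dict=False)

    pooled = out[0]            # text_embeds: the projected pooled output
    penultimate = out[-1][-2]  # hidden_states is the last element; SDXL conditions on layer -2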
@@ -637,6 +637,11 @@ def main(args):
         # only upcast trainable parameters (LoRA) into fp32
         cast_training_params(models, dtype=torch.float32)

+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
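The double unwrap is needed because torch.compile hides the original module behind _orig_mod, and isinstance checks against the base type do not see through the wrapper. A quick runnable illustration (assumes torch >= 2.0):

    import torch
    import torch.nn as nn

    base = nn.Linear(4, 4)
    compiled = torch.compile(base)

    print(type(compiled).__name__)           # OptimizedModule
    print(compiled._orig_mod is base)        # True: the original is still reachable
    print(isinstance(compiled, nn.Linear))   # False: type checks see only the wrapper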
@@ -647,13 +652,13 @@ def main(args):
             text_encoder_two_lora_layers_to_save = None

             for model in models:
-                if isinstance(model, type(accelerator.unwrap_model(unet))):
+                if isinstance(model, type(unwrap_model(unet))):
                     unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
-                elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
+                elif isinstance(model, type(unwrap_model(text_encoder_one))):
                     text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
                         get_peft_model_state_dict(model)
                     )
-                elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
+                elif isinstance(model, type(unwrap_model(text_encoder_two))):
                     text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
                         get_peft_model_state_dict(model)
                     )
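This is where the wrapper previously bit: for a compiled UNet, type(accelerator.unwrap_model(unet)) evaluates to the OptimizedModule wrapper type, so none of the branches match and the LoRA layers are never collected. Schematically:

    # old: type(accelerator.unwrap_model(unet)) -> OptimizedModule   (when compiled)
    # new: type(unwrap_model(unet))             -> UNet2DConditionModel, compiled or not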
@@ -678,11 +683,11 @@ def main(args):
         while len(models) > 0:
             model = models.pop()

-            if isinstance(model, type(accelerator.unwrap_model(unet))):
+            if isinstance(model, type(unwrap_model(unet))):
                 unet_ = model
-            elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
+            elif isinstance(model, type(unwrap_model(text_encoder_one))):
                 text_encoder_one_ = model
-            elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
+            elif isinstance(model, type(unwrap_model(text_encoder_two))):
                 text_encoder_two_ = model
             else:
                 raise ValueError(f"unexpected save model: {model.__class__}")
@@ -1031,8 +1036,12 @@ def main(args):
                 )
                 unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
                 model_pred = unet(
-                    noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions
-                ).sample
+                    noisy_model_input,
+                    timesteps,
+                    prompt_embeds,
+                    added_cond_kwargs=unet_added_conditions,
+                    return_dict=False,
+                )[0]

                 # Get the target for loss depending on the prediction type
                 if args.prediction_type is not None:
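The old .sample read a field off a UNet2DConditionOutput dataclass; constructing and reading ModelOutput objects inside a compiled region is a known source of torch._dynamo graph breaks. With return_dict=False the forward pass hands back a plain tuple, and [0] is the same prediction tensor:

    # default:            unet(...)                    -> UNet2DConditionOutput, tensor at .sample
    # return_dict=False:  unet(..., return_dict=False) -> plain tuple, same tensor at [0]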
@@ -1125,9 +1134,9 @@ def main(args):
                 pipeline = StableDiffusionXLPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
                     vae=vae,
-                    text_encoder=accelerator.unwrap_model(text_encoder_one),
-                    text_encoder_2=accelerator.unwrap_model(text_encoder_two),
-                    unet=accelerator.unwrap_model(unet),
+                    text_encoder=unwrap_model(text_encoder_one),
+                    text_encoder_2=unwrap_model(text_encoder_two),
+                    unet=unwrap_model(unet),
                     revision=args.revision,
                     variant=args.variant,
                     torch_dtype=weight_dtype,
@@ -1166,12 +1175,12 @@ def main(args):
     # Save the lora layers
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
-        unet = accelerator.unwrap_model(unet)
+        unet = unwrap_model(unet)
         unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))

         if args.train_text_encoder:
-            text_encoder_one = accelerator.unwrap_model(text_encoder_one)
-            text_encoder_two = accelerator.unwrap_model(text_encoder_two)
+            text_encoder_one = unwrap_model(text_encoder_one)
+            text_encoder_two = unwrap_model(text_encoder_two)

             text_encoder_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_one))
             text_encoder_2_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_two))