1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Fixes torch.compile() compatible training (#6589)

resolve conflicts
This commit is contained in:
Steve Rhoades
2024-01-16 18:17:03 -08:00
committed by GitHub
parent dd63168319
commit dce06680d2

View File

@@ -68,6 +68,7 @@ from diffusers.utils import (
is_wandb_available,
)
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -1293,6 +1294,11 @@ def main(args):
else:
param.requires_grad = False
def unwrap_model(model):
model = accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
return model
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
@@ -1303,14 +1309,14 @@ def main(args):
text_encoder_two_lora_layers_to_save = None
for model in models:
if isinstance(model, type(accelerator.unwrap_model(unet))):
if isinstance(model, type(unwrap_model(unet))):
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
elif isinstance(model, type(unwrap_model(text_encoder_one))):
if args.train_text_encoder:
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
)
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
elif isinstance(model, type(unwrap_model(text_encoder_two))):
if args.train_text_encoder:
text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
@@ -1338,11 +1344,11 @@ def main(args):
while len(models) > 0:
model = models.pop()
if isinstance(model, type(accelerator.unwrap_model(unet))):
if isinstance(model, type(unwrap_model(unet))):
unet_ = model
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
elif isinstance(model, type(unwrap_model(text_encoder_one))):
text_encoder_one_ = model
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
elif isinstance(model, type(unwrap_model(text_encoder_two))):
text_encoder_two_ = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")