Make InstructPix2Pix SDXL Training Script torch.compile compatible (#6576)
* changes for pix2pix_sdxl
* style fix
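For context, the core of the change is a small unwrap helper plus calling the UNet with return_dict=False, so its output is a plain tuple rather than a ModelOutput dataclass. Below is a minimal, standalone sketch of the unwrap pattern: the Linear stand-in, the explicit torch.compile call, and the Accelerator setup are illustrative assumptions and not part of the training script, while unwrap_model itself mirrors the helper added in the diff that follows.

import torch
from accelerate import Accelerator
from diffusers.utils.torch_utils import is_compiled_module

accelerator = Accelerator()

# Stand-in for the SDXL UNet trained by the script (illustrative only).
model = torch.nn.Linear(4, 4)
model = torch.compile(model)        # wraps the module in an OptimizedModule with ._orig_mod
model = accelerator.prepare(model)  # may add a DDP wrapper in distributed runs


def unwrap_model(model):
    # Same helper the patch adds: strip the accelerate/DDP wrapper first,
    # then the torch.compile wrapper if one is present.
    model = accelerator.unwrap_model(model)
    model = model._orig_mod if is_compiled_module(model) else model
    return model


base = unwrap_model(model)  # plain nn.Module again, safe to pass to a pipeline or to save
print(type(base))           # <class 'torch.nn.modules.linear.Linear'>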
@@ -52,6 +52,7 @@ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_instru
 from diffusers.training_utils import EMAModel
 from diffusers.utils import check_min_version, deprecate, is_wandb_available, load_image
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -531,6 +532,11 @@ def main():
         else:
             raise ValueError("xformers is not available. Make sure it is installed correctly")

+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
     # `accelerate` 0.16.0 will have better support for customized saving
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
@@ -1044,8 +1050,12 @@ def main():
                 added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

-                model_pred = unet(
-                    concatenated_noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
-                ).sample
+                model_pred = unet(
+                    concatenated_noisy_latents,
+                    timesteps,
+                    encoder_hidden_states,
+                    added_cond_kwargs=added_cond_kwargs,
+                    return_dict=False,
+                )[0]
                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                 # Gather the losses across all processes for logging (if we use distributed training).
@@ -1115,7 +1125,7 @@ def main():
                     # The models need unwrapping because for compatibility in distributed training mode.
                     pipeline = StableDiffusionXLInstructPix2PixPipeline.from_pretrained(
                         args.pretrained_model_name_or_path,
-                        unet=accelerator.unwrap_model(unet),
+                        unet=unwrap_model(unet),
                         text_encoder=text_encoder_1,
                         text_encoder_2=text_encoder_2,
                         tokenizer=tokenizer_1,
@@ -1177,7 +1187,7 @@ def main():
     # Create the pipeline using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
-        unet = accelerator.unwrap_model(unet)
+        unet = unwrap_model(unet)
         if args.use_ema:
             ema_unet.copy_to(unet.parameters())

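The other half of the patch swaps the .sample attribute access for return_dict=False plus tuple indexing when calling the UNet, which avoids returning a ModelOutput dataclass during the traced forward pass. A hedged, self-contained sketch follows showing that the two call styles yield the same tensor; the tiny UNet2DConditionModel configuration here is made up purely for illustration and is unrelated to the SDXL configuration the script actually trains.

import torch
from diffusers import UNet2DConditionModel

# Tiny, made-up config purely for demonstration (not the SDXL UNet).
unet = UNet2DConditionModel(
    sample_size=16,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
).eval()

sample = torch.randn(1, 4, 16, 16)
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 8, 32)

with torch.no_grad():
    # ModelOutput dataclass access (the call style the patch replaces)...
    out_dataclass = unet(sample, timestep, encoder_hidden_states).sample
    # ...versus tuple output, which sidesteps the dataclass wrapper.
    out_tuple = unet(sample, timestep, encoder_hidden_states, return_dict=False)[0]

assert torch.allclose(out_dataclass, out_tuple)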