From 05faf3263b9acceffe1ec3dc3be6551703806cb6 Mon Sep 17 00:00:00 2001
From: gzguevara <55751398+gzguevara@users.noreply.github.com>
Date: Mon, 15 Jan 2024 12:19:11 +0100
Subject: [PATCH] SDXL text-to-image torch compatible (#6550)

* torch compatible

* code quality fix

* ruff style

* ruff format
---
 .../text_to_image/train_text_to_image_sdxl.py | 27 ++++++++++++-------
 1 file changed, 17 insertions(+), 10 deletions(-)

diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py
index 0bb57b1f31..2c538c65bc 100644
--- a/examples/text_to_image/train_text_to_image_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_sdxl.py
@@ -44,16 +44,12 @@ from tqdm.auto import tqdm
 from transformers import AutoTokenizer, PretrainedConfig
 
 import diffusers
-from diffusers import (
-    AutoencoderKL,
-    DDPMScheduler,
-    StableDiffusionXLPipeline,
-    UNet2DConditionModel,
-)
+from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionXLPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel, compute_snr
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
 
 
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -508,11 +504,12 @@ def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, ca
             prompt_embeds = text_encoder(
                 text_input_ids.to(text_encoder.device),
                 output_hidden_states=True,
+                return_dict=False,
             )
 
             # We are only ALWAYS interested in the pooled output of the final text encoder
             pooled_prompt_embeds = prompt_embeds[0]
-            prompt_embeds = prompt_embeds.hidden_states[-2]
+            prompt_embeds = prompt_embeds[-1][-2]
             bs_embed, seq_len, _ = prompt_embeds.shape
             prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
             prompt_embeds_list.append(prompt_embeds)
@@ -955,6 +952,12 @@ def main(args):
     if accelerator.is_main_process:
         accelerator.init_trackers("text2image-fine-tune-sdxl", config=vars(args))
 
+    # Function for unwrapping if torch.compile() was used in accelerate.
+    def unwrap_model(model):
+        model = accelerator.unwrap_model(model)
+        model = model._orig_mod if is_compiled_module(model) else model
+        return model
+
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
 
@@ -1054,8 +1057,12 @@ def main(args):
                 pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(accelerator.device)
                 unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
                 model_pred = unet(
-                    noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions
-                ).sample
+                    noisy_model_input,
+                    timesteps,
+                    prompt_embeds,
+                    added_cond_kwargs=unet_added_conditions,
+                    return_dict=False,
+                )[0]
 
                 # Get the target for loss depending on the prediction type
                 if args.prediction_type is not None:
@@ -1206,7 +1213,7 @@ def main(args):
 
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
-        unet = accelerator.unwrap_model(unet)
+        unet = unwrap_model(unet)
         if args.use_ema:
             ema_unet.copy_to(unet.parameters())
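
Note (illustrative, not part of the patch): the change has two parts, a new unwrap_model helper that strips accelerate's wrappers and then the torch.compile wrapper detected via is_compiled_module, and switching the text encoder and UNet calls to return_dict=False so their outputs are plain tuples indexed positionally. Below is a minimal sketch of the unwrap pattern, assuming accelerate wraps the prepared model with torch.compile; the Linear module and the manual torch.compile call are hypothetical stand-ins for the SDXL UNet.

import torch
from accelerate import Accelerator
from diffusers.utils.torch_utils import is_compiled_module

accelerator = Accelerator()

# Hypothetical stand-in for UNet2DConditionModel; compiling it by hand mimics
# what accelerate does when a dynamo backend is configured.
model = torch.compile(torch.nn.Linear(8, 8))
model = accelerator.prepare(model)


def unwrap_model(model):
    # Same helper as in the patch: strip accelerate's wrappers (DDP etc.) first,
    # then the torch.compile wrapper, whose original module lives in `_orig_mod`.
    model = accelerator.unwrap_model(model)
    model = model._orig_mod if is_compiled_module(model) else model
    return model


# Prints torch.nn.Linear: the bare module, which is what the script needs before
# ema_unet.copy_to(unet.parameters()) or saving the final pipeline.
print(type(unwrap_model(model)))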