From 33d2b5b08764d2579dd4ea555821b8986cd2dcf9 Mon Sep 17 00:00:00 2001 From: gzguevara <55751398+gzguevara@users.noreply.github.com> Date: Fri, 12 Jan 2024 04:58:35 +0100 Subject: [PATCH] SD text-to-image torch compile compatible (#6519) * added unwrapper * fiz typo --- examples/text_to_image/train_text_to_image.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 27d7bc5443..fe62c8133c 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -46,6 +46,7 @@ from diffusers.optimization import get_scheduler from diffusers.training_utils import EMAModel, compute_snr from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module if is_wandb_available(): @@ -833,6 +834,12 @@ def main(): tracker_config.pop("validation_prompts") accelerator.init_trackers(args.tracker_project_name, tracker_config) + # Function for unwrapping if model was compiled with `torch.compile`. + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + # Train! total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps @@ -912,7 +919,7 @@ def main(): noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Get the text embedding for conditioning - encoder_hidden_states = text_encoder(batch["input_ids"])[0] + encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0] # Get the target for loss depending on the prediction type if args.prediction_type is not None: @@ -927,7 +934,7 @@ def main(): raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") # Predict the noise residual and compute loss - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0] if args.snr_gamma is None: loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") @@ -1023,7 +1030,7 @@ def main(): # Create the pipeline using the trained modules and save it. accelerator.wait_for_everyone() if accelerator.is_main_process: - unet = accelerator.unwrap_model(unet) + unet = unwrap_model(unet) if args.use_ema: ema_unet.copy_to(unet.parameters())