1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

SD text-to-image torch compile compatible (#6519)

* added unwrapper

* fiz typo
This commit is contained in:
gzguevara
2024-01-12 04:58:35 +01:00
committed by GitHub
parent f486d34b04
commit 33d2b5b087

View File

@@ -46,6 +46,7 @@ from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel, compute_snr
from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
if is_wandb_available():
@@ -833,6 +834,12 @@ def main():
tracker_config.pop("validation_prompts")
accelerator.init_trackers(args.tracker_project_name, tracker_config)
# Function for unwrapping if model was compiled with `torch.compile`.
def unwrap_model(model):
model = accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
return model
# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -912,7 +919,7 @@ def main():
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]
# Get the target for loss depending on the prediction type
if args.prediction_type is not None:
@@ -927,7 +934,7 @@ def main():
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
# Predict the noise residual and compute loss
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
if args.snr_gamma is None:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
@@ -1023,7 +1030,7 @@ def main():
# Create the pipeline using the trained modules and save it.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)
unet = unwrap_model(unet)
if args.use_ema:
ema_unet.copy_to(unet.parameters())