From 495fddb8aec9d5c41ca5b45ce86c0c6301f17e16 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 15 May 2025 21:42:19 +0200 Subject: [PATCH] support mochi --- src/diffusers/pipelines/mochi/pipeline_mochi.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index d1f88b02c5..dc83d3bc1d 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -671,14 +671,15 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin): # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) - noise_pred = self.transformer( - hidden_states=latent_model_input, - encoder_hidden_states=prompt_embeds, - timestep=timestep, - encoder_attention_mask=prompt_attention_mask, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] + with self.transformer.cache_context("cond_uncond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + encoder_attention_mask=prompt_attention_mask, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] # Mochi CFG + Sampling runs in FP32 noise_pred = noise_pred.to(torch.float32)