diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py
index 862c279cfa..a7195d3a67 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux.py
@@ -906,7 +906,7 @@ class FluxPipeline(
         )
 
         # 6. Denoising loop
-        with self.progress_bar(total=num_inference_steps) as progress_bar:
+        with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
@@ -917,6 +917,7 @@ class FluxPipeline(
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
 
+                cc.mark_state("cond")
                 noise_pred = self.transformer(
                     hidden_states=latents,
                     timestep=timestep / 1000,
@@ -932,6 +933,8 @@ class FluxPipeline(
                 if do_true_cfg:
                     if negative_image_embeds is not None:
                         self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+
+                    cc.mark_state("uncond")
                     neg_noise_pred = self.transformer(
                         hidden_states=latents,
                         timestep=timestep / 1000,
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
index e7f3666cb2..e3b49cb673 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
@@ -1061,7 +1061,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
         self._num_timesteps = len(timesteps)
 
         # 7. Denoising loop
-        with self.progress_bar(total=num_inference_steps) as progress_bar:
+        with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
@@ -1090,6 +1090,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
                 timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float()
                 timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
 
+                cc.mark_state("cond_uncond")
                 noise_pred = self.transformer(
                     hidden_states=latent_model_input,
                     encoder_hidden_states=prompt_embeds,
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
index 6c4214fe1b..9ee96e6a39 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
@@ -771,7 +771,7 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
         )
 
         # 7. Denoising loop
-        with self.progress_bar(total=num_inference_steps) as progress_bar:
+        with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
@@ -783,6 +783,7 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
                 timestep = t.expand(latent_model_input.shape[0])
                 timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
 
+                cc.mark_state("cond_uncond")
                 noise_pred = self.transformer(
                     hidden_states=latent_model_input,
                     encoder_hidden_states=prompt_embeds,