From c97b709afa43c2a1b90bd3429ef113fd5848d675 Mon Sep 17 00:00:00 2001
From: Dhruv Nair
Date: Wed, 2 Apr 2025 22:16:31 +0200
Subject: [PATCH] Add CacheMixin to Wan and LTX Transformers (#11187)

* update

* update

* update
---
 src/diffusers/models/transformers/transformer_ltx.py    | 3 ++-
 src/diffusers/models/transformers/transformer_wan.py    | 3 ++-
 src/diffusers/pipelines/ltx/pipeline_ltx.py             | 7 +++++++
 src/diffusers/pipelines/ltx/pipeline_ltx_condition.py   | 7 +++++++
 src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py | 7 +++++++
 5 files changed, 25 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py
index c1f2df5879..2ae2418098 100644
--- a/src/diffusers/models/transformers/transformer_ltx.py
+++ b/src/diffusers/models/transformers/transformer_ltx.py
@@ -26,6 +26,7 @@ from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
 from ..attention_processor import Attention
+from ..cache_utils import CacheMixin
 from ..embeddings import PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
@@ -298,7 +299,7 @@ class LTXVideoTransformerBlock(nn.Module):


 @maybe_allow_in_graph
-class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
+class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin):
     r"""
     A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video).
diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py
index 4eb4add376..aa03e97093 100644
--- a/src/diffusers/models/transformers/transformer_wan.py
+++ b/src/diffusers/models/transformers/transformer_wan.py
@@ -24,6 +24,7 @@ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ..attention import FeedForward
 from ..attention_processor import Attention
+from ..cache_utils import CacheMixin
 from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
@@ -288,7 +289,7 @@ class WanTransformerBlock(nn.Module):
         return hidden_states


-class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
     r"""
     A Transformer model for video-like data used in the Wan model.
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py
index f7b0811d1a..6f3faed8ff 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py
@@ -489,6 +489,10 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
     def num_timesteps(self):
         return self._num_timesteps

+    @property
+    def current_timestep(self):
+        return self._current_timestep
+
     @property
     def attention_kwargs(self):
         return self._attention_kwargs
@@ -622,6 +626,7 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
         self._guidance_scale = guidance_scale
         self._attention_kwargs = attention_kwargs
         self._interrupt = False
+        self._current_timestep = None

         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -706,6 +711,8 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
             if self.interrupt:
                 continue

+            self._current_timestep = t
+
             latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
             latent_model_input = latent_model_input.to(prompt_embeds.dtype)
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
index e7f3666cb2..ef1fd56839 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
@@ -774,6 +774,10 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
     def num_timesteps(self):
         return self._num_timesteps

+    @property
+    def current_timestep(self):
+        return self._current_timestep
+
     @property
     def attention_kwargs(self):
         return self._attention_kwargs
@@ -933,6 +937,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
         self._guidance_scale = guidance_scale
         self._attention_kwargs = attention_kwargs
         self._interrupt = False
+        self._current_timestep = None

         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -1066,6 +1071,8 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
             if self.interrupt:
                 continue

+            self._current_timestep = t
+
             if image_cond_noise_scale > 0:
                 # Add timestep-dependent noise to the hard-conditioning latents
                 # This helps with motion continuity, especially when conditioned on a single frame
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
index 0f640dc335..1ae67967c6 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
@@ -550,6 +550,10 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
     def num_timesteps(self):
         return self._num_timesteps

+    @property
+    def current_timestep(self):
+        return self._current_timestep
+
     @property
     def attention_kwargs(self):
         return self._attention_kwargs
@@ -686,6 +690,7 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
         self._guidance_scale = guidance_scale
         self._attention_kwargs = attention_kwargs
         self._interrupt = False
+        self._current_timestep = None

         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -778,6 +783,8 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
             if self.interrupt:
                 continue

+            self._current_timestep = t
+
             latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
             latent_model_input = latent_model_input.to(prompt_embeds.dtype)
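
Usage sketch (illustrative, not part of the patch): with CacheMixin on the LTX and Wan transformers, cache techniques exposed through the generic enable_cache() API become available for these models, and the new current_timestep property gives the cache hooks a callback into the denoising loop. The config values and prompt below are assumptions chosen only for illustration; the checkpoint name comes from the model docstring above.

    import torch
    from diffusers import LTXPipeline, PyramidAttentionBroadcastConfig

    # Load the LTX-Video pipeline (checkpoint referenced in the transformer docstring).
    pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    # Skip ranges here are illustrative; tune them per model and schedule.
    config = PyramidAttentionBroadcastConfig(
        spatial_attention_block_skip_range=2,
        spatial_attention_timestep_skip_range=(100, 800),
        # The pipeline property added in this patch lets the cache observe
        # the timestep currently being denoised.
        current_timestep_callback=lambda: pipe.current_timestep,
    )

    # enable_cache() is provided by the CacheMixin added to LTXVideoTransformer3DModel.
    pipe.transformer.enable_cache(config)

    video = pipe(prompt="A chimpanzee typing on a typewriter", num_frames=65).frames[0]

The same pattern applies to WanTransformer3DModel once a Wan pipeline exposes a comparable current_timestep hook; pipe.transformer.disable_cache() restores the uncached behavior.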