From e5ebacb820d190eb17c4a20d84ebba52f85da712 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Wed, 28 Jan 2026 12:31:24 +0530
Subject: [PATCH] fix

---
 .../models/transformers/transformer_chroma.py      |  3 ++-
 .../transformers/transformer_chronoedit.py         |  3 ++-
 .../transformers/transformer_glm_image.py          |  3 +--
 .../transformers/transformer_hunyuan_video.py      |  3 ++-
 .../transformers/transformer_hunyuan_video15.py    |  3 ++-
 .../transformer_hunyuan_video_framepack.py         |  3 ++-
 .../transformers/transformer_hunyuanimage.py       |  3 ++-
 .../models/transformers/transformer_ltx.py         |  3 ++-
 .../models/transformers/transformer_ltx2.py        |  7 ++-----
 .../transformers/transformer_qwenimage.py          | 17 ++---------------
 .../transformers/transformer_skyreels_v2.py        |  3 ++-
 .../models/transformers/transformer_wan.py         |  3 ++-
 .../transformers/transformer_wan_animate.py        |  3 ++-
 .../models/transformers/transformer_wan_vace.py    |  3 ++-
 14 files changed, 27 insertions(+), 33 deletions(-)

diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py
index 091ce9c66a..37b4e4e284 100644
--- a/src/diffusers/models/transformers/transformer_chroma.py
+++ b/src/diffusers/models/transformers/transformer_chroma.py
@@ -21,7 +21,7 @@ import torch.nn as nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import deprecate, logging
+from ...utils import apply_lora_scale, deprecate, logging
 from ...utils.import_utils import is_torch_npu_available
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import AttentionMixin, FeedForward
@@ -473,6 +473,7 @@ class ChromaTransformer2DModel(
 
         self.gradient_checkpointing = False
 
+    @apply_lora_scale("joint_attention_kwargs")
     def forward(
         self,
         hidden_states: torch.Tensor,
diff --git a/src/diffusers/models/transformers/transformer_chronoedit.py b/src/diffusers/models/transformers/transformer_chronoedit.py
index 3ef131f16b..8742cf2951 100644
--- a/src/diffusers/models/transformers/transformer_chronoedit.py
+++ b/src/diffusers/models/transformers/transformer_chronoedit.py
@@ -21,7 +21,7 @@ import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import deprecate, logging
+from ...utils import apply_lora_scale, deprecate, logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
 from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
@@ -638,6 +638,7 @@ class ChronoEditTransformer3DModel(
 
         self.gradient_checkpointing = False
 
+    @apply_lora_scale("attention_kwargs")
     def forward(
         self,
         hidden_states: torch.Tensor,
diff --git a/src/diffusers/models/transformers/transformer_glm_image.py b/src/diffusers/models/transformers/transformer_glm_image.py
index c12bd59c91..6f7ed2fca1 100644
--- a/src/diffusers/models/transformers/transformer_glm_image.py
+++ b/src/diffusers/models/transformers/transformer_glm_image.py
@@ -20,7 +20,7 @@ import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
-from ...utils import apply_lora_scale, logging
+from ...utils import logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
 from ..attention_dispatch import dispatch_attention_fn
@@ -595,7 +595,6 @@ class GlmImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
 
         self.gradient_checkpointing = False
 
-    @apply_lora_scale("attention_kwargs")
     def forward(
         self,
         hidden_states: torch.Tensor,
diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py
index a7211c82dd..84dcb1fe40 100644
--- a/src/diffusers/models/transformers/transformer_hunyuan_video.py
+++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py
@@ -22,7 +22,7 @@ from diffusers.loaders import FromOriginalModelMixin
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
-from ...utils import logging
+from ...utils import apply_lora_scale, logging
 from ..attention import AttentionMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
 from ..attention_processor import Attention
@@ -989,6 +989,7 @@ class HunyuanVideoTransformer3DModel(
 
         self.gradient_checkpointing = False
 
+    @apply_lora_scale("attention_kwargs")
     def forward(
         self,
         hidden_states: torch.Tensor,
diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video15.py b/src/diffusers/models/transformers/transformer_hunyuan_video15.py
index 67fa695688..8595595326 100644
--- a/src/diffusers/models/transformers/transformer_hunyuan_video15.py
+++ b/src/diffusers/models/transformers/transformer_hunyuan_video15.py
@@ -22,7 +22,7 @@ from diffusers.loaders import FromOriginalModelMixin
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
-from ...utils import logging
+from ...utils import apply_lora_scale, logging
 from ..attention import AttentionMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
 from ..attention_processor import Attention
@@ -620,6 +620,7 @@ class HunyuanVideo15Transformer3DModel(
 
         self.gradient_checkpointing = False
 
+    @apply_lora_scale("attention_kwargs")
     def forward(
         self,
         hidden_states: torch.Tensor,
diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py
index e9b177206e..500cec89f8 100644
--- a/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py
+++ b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py
@@ -20,7 +20,7 @@ import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import get_logger
+from ...utils import apply_lora_scale, get_logger
 from ..cache_utils import CacheMixin
 from ..embeddings import get_1d_rotary_pos_embed
 from ..modeling_outputs import Transformer2DModelOutput
@@ -198,6 +198,7 @@ class HunyuanVideoFramepackTransformer3DModel(
 
         self.gradient_checkpointing = False
 
+    @apply_lora_scale("attention_kwargs")
     def forward(
         self,
         hidden_states: torch.Tensor,
diff --git a/src/diffusers/models/transformers/transformer_hunyuanimage.py b/src/diffusers/models/transformers/transformer_hunyuanimage.py
index 9f52b10ba3..dc4b22c323 100644
--- a/src/diffusers/models/transformers/transformer_hunyuanimage.py
+++ b/src/diffusers/models/transformers/transformer_hunyuanimage.py
@@ -23,7 +23,7 @@ from diffusers.loaders import FromOriginalModelMixin
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
-from ...utils import logging
+from ...utils import apply_lora_scale, logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import AttentionMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
@@ -742,6 +742,7 @@ class HunyuanImageTransformer2DModel(
 
         self.gradient_checkpointing = False
 
+    @apply_lora_scale("attention_kwargs")
     def forward(
         self,
         hidden_states: torch.Tensor,
diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py
index 2928af7db3..4bb0eb9268 100644
--- a/src/diffusers/models/transformers/transformer_ltx.py
+++ b/src/diffusers/models/transformers/transformer_ltx.py
@@ -22,7 +22,7 @@ import torch.nn as nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import deprecate, is_torch_version, logging
+from ...utils import apply_lora_scale, deprecate, is_torch_version, logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
 from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
@@ -491,6 +491,7 @@ class LTXVideoTransformer3DModel(
 
         self.gradient_checkpointing = False
 
+    @apply_lora_scale("attention_kwargs")
     def forward(
         self,
         hidden_states: torch.Tensor,
diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py
index b67ace6f22..62bb0dfb1f 100644
--- a/src/diffusers/models/transformers/transformer_ltx2.py
+++ b/src/diffusers/models/transformers/transformer_ltx2.py
@@ -22,11 +22,7 @@ import torch.nn as nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import (
-    BaseOutput,
-    is_torch_version,
-    logging,
-)
+from ...utils import BaseOutput, apply_lora_scale, is_torch_version, logging
 from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
 from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
@@ -1098,6 +1094,7 @@ class LTX2VideoTransformer3DModel(
 
         self.gradient_checkpointing = False
 
+    @apply_lora_scale("attention_kwargs")
     def forward(
         self,
         hidden_states: torch.Tensor,
diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py
index 771271ada4..b2d4bdcc8b 100644
--- a/src/diffusers/models/transformers/transformer_qwenimage.py
+++ b/src/diffusers/models/transformers/transformer_qwenimage.py
@@ -24,7 +24,7 @@ import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers
+from ...utils import apply_lora_scale, deprecate, logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
 from ..attention import AttentionMixin, FeedForward
@@ -829,6 +829,7 @@ class QwenImageTransformer2DModel(
         self.gradient_checkpointing = False
         self.zero_cond_t = zero_cond_t
 
+    @apply_lora_scale("attention_kwargs")
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -887,20 +888,6 @@ class QwenImageTransformer2DModel(
                 "The mask-based approach is more flexible and supports variable-length sequences.",
                 standard_warn=False,
             )
-        if attention_kwargs is not None:
-            attention_kwargs = attention_kwargs.copy()
-            lora_scale = attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
-                )
 
         hidden_states = self.img_in(hidden_states)
 
diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py
index dac92bd6af..0a2be5a311 100644
--- a/src/diffusers/models/transformers/transformer_skyreels_v2.py
+++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py
@@ -21,7 +21,7 @@ import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import deprecate, logging
+from ...utils import apply_lora_scale, deprecate, logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
@@ -630,6 +630,7 @@ class SkyReelsV2Transformer3DModel(
 
         self.gradient_checkpointing = False
 
+    @apply_lora_scale("attention_kwargs")
     def forward(
         self,
         hidden_states: torch.Tensor,
diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py
index 606de14f05..755a88dfda 100644
--- a/src/diffusers/models/transformers/transformer_wan.py
+++ b/src/diffusers/models/transformers/transformer_wan.py
@@ -21,7 +21,7 @@ import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import deprecate, logging
+from ...utils import apply_lora_scale, deprecate, logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
 from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
@@ -622,6 +622,7 @@ class WanTransformer3DModel(
 
         self.gradient_checkpointing = False
 
+    @apply_lora_scale("attention_kwargs")
     def forward(
         self,
         hidden_states: torch.Tensor,
diff --git a/src/diffusers/models/transformers/transformer_wan_animate.py b/src/diffusers/models/transformers/transformer_wan_animate.py
index 1780f6e219..c5c2d0052b 100644
--- a/src/diffusers/models/transformers/transformer_wan_animate.py
+++ b/src/diffusers/models/transformers/transformer_wan_animate.py
@@ -21,7 +21,7 @@ import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import logging
+from ...utils import apply_lora_scale, logging
 from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
 from ..cache_utils import CacheMixin
@@ -1140,6 +1140,7 @@ class WanAnimateTransformer3DModel(
 
         self.gradient_checkpointing = False
 
+    @apply_lora_scale("attention_kwargs")
     def forward(
         self,
         hidden_states: torch.Tensor,
diff --git a/src/diffusers/models/transformers/transformer_wan_vace.py b/src/diffusers/models/transformers/transformer_wan_vace.py
index 80efbe1350..1c84b4628e 100644
--- a/src/diffusers/models/transformers/transformer_wan_vace.py
+++ b/src/diffusers/models/transformers/transformer_wan_vace.py
@@ -20,7 +20,7 @@ import torch.nn as nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import logging
+from ...utils import apply_lora_scale, logging
 from ..attention import AttentionMixin, FeedForward
 from ..cache_utils import CacheMixin
 from ..modeling_outputs import Transformer2DModelOutput
@@ -261,6 +261,7 @@ class WanVACETransformer3DModel(
 
         self.gradient_checkpointing = False
 
+    @apply_lora_scale("attention_kwargs")
     def forward(
         self,
         hidden_states: torch.Tensor,
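
Note (illustrative, not part of the patch): the change above replaces per-model LoRA-scale boilerplate with a single `@apply_lora_scale(...)` decorator on each `forward`, parameterized by the name of the kwargs dict that may carry a `scale` entry ("attention_kwargs" for most models, "joint_attention_kwargs" for Chroma). The sketch below shows roughly what such a decorator could do, reconstructed from the inline logic this patch removes from transformer_qwenimage.py. It is an assumption, not the actual `diffusers.utils.apply_lora_scale`: the name `apply_lora_scale_sketch`, the keyword-only lookup of the kwargs dict, and the use of `unscale_lora_layers` to restore weights after the call are guesses.

    # Illustrative sketch only -- NOT the diffusers implementation of `apply_lora_scale`.
    import functools

    from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers

    logger = logging.get_logger(__name__)


    def apply_lora_scale_sketch(kwargs_name: str):
        """Pop `scale` from the kwargs dict named `kwargs_name` and scale PEFT LoRA
        layers around the wrapped `forward` call."""

        def decorator(forward_fn):
            @functools.wraps(forward_fn)
            def wrapper(self, *args, **kwargs):
                # Assumes the kwargs dict (e.g. `attention_kwargs`) is passed by keyword.
                attention_kwargs = kwargs.get(kwargs_name)
                if attention_kwargs is not None:
                    attention_kwargs = dict(attention_kwargs)  # do not mutate the caller's dict
                    lora_scale = attention_kwargs.pop("scale", 1.0)
                    kwargs[kwargs_name] = attention_kwargs
                else:
                    lora_scale = 1.0

                if USE_PEFT_BACKEND:
                    # Weight the LoRA layers by setting `lora_scale` on each PEFT layer.
                    scale_lora_layers(self, lora_scale)
                elif lora_scale != 1.0:
                    logger.warning(
                        f"Passing `scale` via `{kwargs_name}` when not using the PEFT backend is ineffective."
                    )

                try:
                    return forward_fn(self, *args, **kwargs)
                finally:
                    if USE_PEFT_BACKEND:
                        # Restore the original LoRA weights after the forward pass.
                        unscale_lora_layers(self, lora_scale)

            return wrapper

        return decorator

Callers should be unaffected: passing e.g. `attention_kwargs={"scale": 0.8}` to a decorated `forward` would behave like the removed inline code, with the `scale` key stripped from the dict before the model body runs.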