diff --git a/src/diffusers/models/controlnets/controlnet.py b/src/diffusers/models/controlnets/controlnet.py
index 0b5b9fa3ef..c25bd2b986 100644
--- a/src/diffusers/models/controlnets/controlnet.py
+++ b/src/diffusers/models/controlnets/controlnet.py
@@ -21,7 +21,7 @@ from torch.nn import functional as F
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
 from ...loaders.single_file_model import FromOriginalModelMixin
-from ...utils import BaseOutput, logging
+from ...utils import BaseOutput, apply_lora_scale, logging
 from ..attention import AttentionMixin
 from ..attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
@@ -598,6 +598,7 @@ class ControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, FromOriginalModel
         for module in self.children():
             fn_recursive_set_attention_slice(module, reversed_slice_size)

+    @apply_lora_scale("cross_attention_kwargs")
     def forward(
         self,
         sample: torch.Tensor,
diff --git a/src/diffusers/models/controlnets/controlnet_flux.py b/src/diffusers/models/controlnets/controlnet_flux.py
index 639a8ad739..e1273c2d34 100644
--- a/src/diffusers/models/controlnets/controlnet_flux.py
+++ b/src/diffusers/models/controlnets/controlnet_flux.py
@@ -20,7 +20,14 @@ import torch.nn as nn

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    USE_PEFT_BACKEND,
+    BaseOutput,
+    apply_lora_scale,
+    logging,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 from ..attention import AttentionMixin
 from ..controlnets.controlnet import ControlNetConditioningEmbedding, zero_module
 from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
@@ -150,6 +157,7 @@ class FluxControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMi

         return controlnet

+    @apply_lora_scale("joint_attention_kwargs")
     def forward(
         self,
         hidden_states: torch.Tensor,
diff --git a/src/diffusers/models/controlnets/controlnet_qwenimage.py b/src/diffusers/models/controlnets/controlnet_qwenimage.py
index fa374285ee..b7b7abbff8 100644
--- a/src/diffusers/models/controlnets/controlnet_qwenimage.py
+++ b/src/diffusers/models/controlnets/controlnet_qwenimage.py
@@ -20,7 +20,15 @@ import torch.nn as nn

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    USE_PEFT_BACKEND,
+    BaseOutput,
+    apply_lora_scale,
+    deprecate,
+    logging,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 from ..attention import AttentionMixin
 from ..cache_utils import CacheMixin
 from ..controlnets.controlnet import zero_module
@@ -123,6 +131,7 @@ class QwenImageControlNetModel(

         return controlnet

+    @apply_lora_scale("joint_attention_kwargs")
     def forward(
         self,
         hidden_states: torch.Tensor,
diff --git a/src/diffusers/models/controlnets/controlnet_sana.py b/src/diffusers/models/controlnets/controlnet_sana.py
index c71a8b3266..a739f12080 100644
--- a/src/diffusers/models/controlnets/controlnet_sana.py
+++ b/src/diffusers/models/controlnets/controlnet_sana.py
@@ -20,7 +20,7 @@ from torch import nn

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import BaseOutput, apply_lora_scale, logging
 from ..attention import AttentionMixin
 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput
@@ -117,6 +117,7 @@ class SanaControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMi

         self.gradient_checkpointing = False

+    @apply_lora_scale("attention_kwargs")
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -129,21 +130,6 @@ class SanaControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMi
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
     ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
-        if attention_kwargs is not None:
-            attention_kwargs = attention_kwargs.copy()
-            lora_scale = attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
-                )
-
         # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
         # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
         # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
@@ -218,10 +204,6 @@ class SanaControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMi
             block_res_sample = controlnet_block(block_res_sample)
             controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)

-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
-
         controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples]

         if not return_dict:
diff --git a/src/diffusers/models/controlnets/controlnet_sd3.py b/src/diffusers/models/controlnets/controlnet_sd3.py
index 08b86ff344..db758422cb 100644
--- a/src/diffusers/models/controlnets/controlnet_sd3.py
+++ b/src/diffusers/models/controlnets/controlnet_sd3.py
@@ -21,7 +21,7 @@ import torch.nn as nn

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, apply_lora_scale, logging, scale_lora_layers, unscale_lora_layers
 from ..attention import AttentionMixin, JointTransformerBlock
 from ..attention_processor import Attention, FusedJointAttnProcessor2_0
 from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
@@ -269,6 +269,7 @@ class SD3ControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMix

         return controlnet

+    @apply_lora_scale("joint_attention_kwargs")
     def forward(
         self,
         hidden_states: torch.Tensor,
diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py
index e3732662e4..48578dc5d3 100644
--- a/src/diffusers/models/transformers/auraflow_transformer_2d.py
+++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py
@@ -21,7 +21,7 @@ import torch.nn.functional as F

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import apply_lora_scale, logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import AttentionMixin
 from ..attention_processor import (
@@ -397,6 +397,7 @@ class AuraFlowTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAd
         if self.original_attn_processors is not None:
             self.set_attn_processor(self.original_attn_processors)

+    @apply_lora_scale("attention_kwargs")
     def forward(
         self,
         hidden_states: torch.FloatTensor,
@@ -405,21 +406,6 @@ class AuraFlowTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAd
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
     ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
-        if attention_kwargs is not None:
-            attention_kwargs = attention_kwargs.copy()
-            lora_scale = attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
-                )
-
         height, width = hidden_states.shape[-2:]

         # Apply patch embedding, timestep embedding, and project the caption embeddings.
@@ -486,10 +472,6 @@ class AuraFlowTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAd
             shape=(hidden_states.shape[0], out_channels, height * patch_size, width * patch_size)
         )

-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
-
         if not return_dict:
             return (output,)

diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py
index 14b38cd46c..f48bbc4ef5 100644
--- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py
+++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py
@@ -20,7 +20,7 @@ from torch import nn

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import apply_lora_scale, logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import Attention, AttentionMixin, FeedForward
 from ..attention_processor import CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
@@ -363,6 +363,7 @@ class CogVideoXTransformer3DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftA
         if self.original_attn_processors is not None:
             self.set_attn_processor(self.original_attn_processors)

+    @apply_lora_scale("attention_kwargs")
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -374,21 +375,6 @@ class CogVideoXTransformer3DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftA
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
     ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
-        if attention_kwargs is not None:
-            attention_kwargs = attention_kwargs.copy()
-            lora_scale = attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
-                )
-
         batch_size, num_frames, channels, height, width = hidden_states.shape

         # 1. Time embedding
@@ -454,10 +440,6 @@ class CogVideoXTransformer3DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftA
         )
         output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)

-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
-
         if not return_dict:
             return (output,)
         return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/models/transformers/consisid_transformer_3d.py b/src/diffusers/models/transformers/consisid_transformer_3d.py
index be20b0a3ea..9458712dda 100644
--- a/src/diffusers/models/transformers/consisid_transformer_3d.py
+++ b/src/diffusers/models/transformers/consisid_transformer_3d.py
@@ -20,7 +20,7 @@ from torch import nn

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import apply_lora_scale, logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import Attention, AttentionMixin, FeedForward
 from ..attention_processor import CogVideoXAttnProcessor2_0
@@ -620,6 +620,7 @@ class ConsisIDTransformer3DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAd
             ]
         )

+    @apply_lora_scale("attention_kwargs")
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -632,21 +633,6 @@ class ConsisIDTransformer3DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAd
         id_vit_hidden: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
-        if attention_kwargs is not None:
-            attention_kwargs = attention_kwargs.copy()
-            lora_scale = attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
-                )
-
         # fuse clip and insightface
         valid_face_emb = None
         if self.is_train_face:
@@ -720,10 +706,6 @@ class ConsisIDTransformer3DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAd
         output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
         output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)

-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
-
         if not return_dict:
             return (output,)
         return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py
index d45badb1b1..2bee436daa 100644
--- a/src/diffusers/models/transformers/sana_transformer.py
+++ b/src/diffusers/models/transformers/sana_transformer.py
@@ -20,7 +20,7 @@ from torch import nn

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import apply_lora_scale, logging
 from ..attention import AttentionMixin
 from ..attention_processor import (
     Attention,
@@ -414,6 +414,7 @@ class SanaTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapte

         self.gradient_checkpointing = False

+    @apply_lora_scale("attention_kwargs")
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -426,21 +427,6 @@ class SanaTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapte
         controlnet_block_samples: Optional[Tuple[torch.Tensor]] = None,
         return_dict: bool = True,
     ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
-        if attention_kwargs is not None:
-            attention_kwargs = attention_kwargs.copy()
-            lora_scale = attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
-                )
-
         # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
         # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
         # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
@@ -527,10 +513,6 @@ class SanaTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapte
         hidden_states = hidden_states.permute(0, 5, 1, 3, 2, 4)
         output = hidden_states.reshape(batch_size, -1, post_patch_height * p, post_patch_width * p)

-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
-
         if not return_dict:
             return (output,)

diff --git a/src/diffusers/models/transformers/transformer_bria.py b/src/diffusers/models/transformers/transformer_bria.py
index d54679306e..05c8a6f27b 100644
--- a/src/diffusers/models/transformers/transformer_bria.py
+++ b/src/diffusers/models/transformers/transformer_bria.py
@@ -8,7 +8,7 @@ import torch.nn.functional as F

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import apply_lora_scale, logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import AttentionModuleMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
@@ -581,6 +581,7 @@ class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig

         self.gradient_checkpointing = False

+    @apply_lora_scale("attention_kwargs")
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -621,20 +622,6 @@ class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
             If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise
             a `tuple` where the first element is the sample tensor.
         """
-        if attention_kwargs is not None:
-            attention_kwargs = attention_kwargs.copy()
-            lora_scale = attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
-                )
         hidden_states = self.x_embedder(hidden_states)

         timestep = timestep.to(hidden_states.dtype)
@@ -715,10 +702,6 @@ class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
         hidden_states = self.norm_out(hidden_states, temb)
         output = self.proj_out(hidden_states)

-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
-
         if not return_dict:
             return (output,)

diff --git a/src/diffusers/models/transformers/transformer_bria_fibo.py b/src/diffusers/models/transformers/transformer_bria_fibo.py
index 09f7961932..d8a879d66d 100644
--- a/src/diffusers/models/transformers/transformer_bria_fibo.py
+++ b/src/diffusers/models/transformers/transformer_bria_fibo.py
@@ -23,6 +23,7 @@ from ...models.modeling_utils import ModelMixin
 from ...models.transformers.transformer_bria import BriaAttnProcessor
 from ...utils import (
     USE_PEFT_BACKEND,
+    apply_lora_scale,
     logging,
     scale_lora_layers,
     unscale_lora_layers,
@@ -510,6 +511,7 @@ class BriaFiboTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
         ]
         self.caption_projection = nn.ModuleList(caption_projection)

+    @apply_lora_scale("joint_attention_kwargs")
     def forward(
         self,
         hidden_states: torch.Tensor,
diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py
index 64e9a538a7..c77494c2b1 100644
--- a/src/diffusers/models/transformers/transformer_cogview4.py
+++ b/src/diffusers/models/transformers/transformer_cogview4.py
@@ -20,7 +20,7 @@ import torch.nn.functional as F

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import apply_lora_scale, logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
 from ..attention_processor import Attention
@@ -703,6 +703,7 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach

         self.gradient_checkpointing = False

+    @apply_lora_scale("attention_kwargs")
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -718,21 +719,6 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
             Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
         ] = None,
     ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
-        if attention_kwargs is not None:
-            attention_kwargs = attention_kwargs.copy()
-            lora_scale = attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
-                )
-
         batch_size, num_channels, height, width = hidden_states.shape

         # 1. RoPE
@@ -779,10 +765,6 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
         hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p)
         output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)

-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
-
         if not return_dict:
             return (output,)
         return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py
index 9cadfcefc4..e0ba3b21e5 100644
--- a/src/diffusers/models/transformers/transformer_flux2.py
+++ b/src/diffusers/models/transformers/transformer_flux2.py
@@ -21,7 +21,7 @@ import torch.nn.functional as F

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, apply_lora_scale, logging, scale_lora_layers, unscale_lora_layers
 from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
 from ..attention import AttentionMixin, AttentionModuleMixin
 from ..attention_dispatch import dispatch_attention_fn
@@ -774,6 +774,7 @@ class Flux2Transformer2DModel(

         self.gradient_checkpointing = False

+    @apply_lora_scale("joint_attention_kwargs")
     def forward(
         self,
         hidden_states: torch.Tensor,
diff --git a/src/diffusers/models/transformers/transformer_glm_image.py b/src/diffusers/models/transformers/transformer_glm_image.py
index 6f7ed2fca1..c12bd59c91 100644
--- a/src/diffusers/models/transformers/transformer_glm_image.py
+++ b/src/diffusers/models/transformers/transformer_glm_image.py
@@ -20,7 +20,7 @@ import torch.nn.functional as F

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
-from ...utils import logging
+from ...utils import apply_lora_scale, logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
 from ..attention_dispatch import dispatch_attention_fn
@@ -595,6 +595,7 @@ class GlmImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach

         self.gradient_checkpointing = False

+    @apply_lora_scale("attention_kwargs")
     def forward(
         self,
         hidden_states: torch.Tensor,
diff --git a/src/diffusers/models/transformers/transformer_hidream_image.py b/src/diffusers/models/transformers/transformer_hidream_image.py
index 4a5aee29ab..0e9c947b51 100644
--- a/src/diffusers/models/transformers/transformer_hidream_image.py
+++ b/src/diffusers/models/transformers/transformer_hidream_image.py
@@ -8,7 +8,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...models.modeling_outputs import Transformer2DModelOutput
 from ...models.modeling_utils import ModelMixin
-from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import apply_lora_scale, deprecate, logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import Attention
 from ..embeddings import TimestepEmbedding, Timesteps
@@ -773,6 +773,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,

         return hidden_states, hidden_states_masks, img_sizes, img_ids

+    @apply_lora_scale("attention_kwargs")
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -808,21 +809,6 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
                 "if `hidden_states_masks` is passed, `hidden_states` must be a 3D tensors with shape (batch_size, patch_height * patch_width, patch_size * patch_size * channels)"
             )

-        if attention_kwargs is not None:
-            attention_kwargs = attention_kwargs.copy()
-            lora_scale = attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
-                )
-
         # spatial forward
         batch_size = hidden_states.shape[0]
         hidden_states_type = hidden_states.dtype
@@ -933,10 +919,6 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
         if hidden_states_masks is not None:
             hidden_states_masks = hidden_states_masks[:, :image_tokens_seq_len]

-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
-
         if not return_dict:
             return (output,)
         return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/models/transformers/transformer_lumina2.py b/src/diffusers/models/transformers/transformer_lumina2.py
index 77121edb9f..ecc70b06ef 100644
--- a/src/diffusers/models/transformers/transformer_lumina2.py
+++ b/src/diffusers/models/transformers/transformer_lumina2.py
@@ -22,7 +22,7 @@ import torch.nn.functional as F
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
 from ...loaders.single_file_model import FromOriginalModelMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import apply_lora_scale, logging
 from ..attention import LuminaFeedForward
 from ..attention_processor import Attention
 from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed
@@ -455,6 +455,7 @@ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO

         self.gradient_checkpointing = False

+    @apply_lora_scale("attention_kwargs")
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -464,21 +465,6 @@ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
-        if attention_kwargs is not None:
-            attention_kwargs = attention_kwargs.copy()
-            lora_scale = attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
-                )
-
         # 1. Condition, positional & patch embedding
         batch_size, _, height, width = hidden_states.shape

@@ -539,10 +525,6 @@ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
         )
         output = torch.stack(output, dim=0)

-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
-
         if not return_dict:
             return (output,)
         return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py
index 63911fe7c1..d0cf82dea5 100644
--- a/src/diffusers/models/transformers/transformer_mochi.py
+++ b/src/diffusers/models/transformers/transformer_mochi.py
@@ -21,7 +21,7 @@ import torch.nn as nn
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
 from ...loaders.single_file_model import FromOriginalModelMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import apply_lora_scale, logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
 from ..attention_processor import MochiAttention, MochiAttnProcessor2_0
@@ -404,6 +404,7 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOri

         self.gradient_checkpointing = False

+    @apply_lora_scale("attention_kwargs")
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -413,21 +414,6 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOri
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
     ) -> torch.Tensor:
-        if attention_kwargs is not None:
-            attention_kwargs = attention_kwargs.copy()
-            lora_scale = attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
-                )
-
         batch_size, num_channels, num_frames, height, width = hidden_states.shape
         p = self.config.patch_size

@@ -479,10 +465,6 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOri
         hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5)
         output = hidden_states.reshape(batch_size, -1, num_frames, height, width)

-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
-
         if not return_dict:
             return (output,)
         return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/models/transformers/transformer_sana_video.py b/src/diffusers/models/transformers/transformer_sana_video.py
index a4f9034263..f6a3aff3a4 100644
--- a/src/diffusers/models/transformers/transformer_sana_video.py
+++ b/src/diffusers/models/transformers/transformer_sana_video.py
@@ -21,7 +21,7 @@ from torch import nn

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import apply_lora_scale, logging
 from ..attention import AttentionMixin
 from ..attention_dispatch import dispatch_attention_fn
 from ..attention_processor import Attention
@@ -570,6 +570,7 @@ class SanaVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro

         self.gradient_checkpointing = False

+    @apply_lora_scale("attention_kwargs")
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -582,21 +583,6 @@ class SanaVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
         controlnet_block_samples: Optional[Tuple[torch.Tensor]] = None,
         return_dict: bool = True,
     ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
-        if attention_kwargs is not None:
-            attention_kwargs = attention_kwargs.copy()
-            lora_scale = attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
-                )
-
         # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
         # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
         # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
@@ -695,10 +681,6 @@ class SanaVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
         hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
         output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
-
         if not return_dict:
             return (output,)

diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py
index 05391e047b..129f2c59a2 100644
--- a/src/diffusers/models/transformers/transformer_sd3.py
+++ b/src/diffusers/models/transformers/transformer_sd3.py
@@ -18,7 +18,7 @@ import torch.nn as nn

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, apply_lora_scale, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import AttentionMixin, FeedForward, JointTransformerBlock
 from ..attention_processor import (
@@ -245,6 +245,7 @@ class SD3Transformer2DModel(
         if self.original_attn_processors is not None:
             self.set_attn_processor(self.original_attn_processors)

+    @apply_lora_scale("joint_attention_kwargs")
     def forward(
         self,
         hidden_states: torch.Tensor,
diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py
index e669aa51a5..f9d1621f44 100644
--- a/src/diffusers/models/unets/unet_2d_condition.py
+++ b/src/diffusers/models/unets/unet_2d_condition.py
@@ -20,7 +20,15 @@ import torch.nn as nn
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
 from ...loaders.single_file_model import FromOriginalModelMixin
-from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    USE_PEFT_BACKEND,
+    BaseOutput,
+    apply_lora_scale,
+    deprecate,
+    logging,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 from ..activations import get_activation
 from ..attention import AttentionMixin
 from ..attention_processor import (
@@ -974,6 +982,7 @@ class UNet2DConditionModel(
             encoder_hidden_states = (encoder_hidden_states, image_embeds)
         return encoder_hidden_states

+    @apply_lora_scale("cross_attention_kwargs")
     def forward(
         self,
         sample: torch.Tensor,
diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py
index 5a93541501..a96afa3393 100644
--- a/src/diffusers/models/unets/unet_motion_model.py
+++ b/src/diffusers/models/unets/unet_motion_model.py
@@ -21,7 +21,7 @@ import torch.nn.functional as F

 from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
-from ...utils import BaseOutput, deprecate, logging
+from ...utils import BaseOutput, apply_lora_scale, deprecate, logging
 from ...utils.torch_utils import apply_freeu
 from ..attention import AttentionMixin, BasicTransformerBlock
 from ..attention_processor import (
@@ -1875,6 +1875,7 @@ class UNetMotionModel(ModelMixin, AttentionMixin, ConfigMixin, UNet2DConditionLo
         if self.original_attn_processors is not None:
             self.set_attn_processor(self.original_attn_processors)

+    @apply_lora_scale("cross_attention_kwargs")
     def forward(
         self,
         sample: torch.Tensor,
diff --git a/src/diffusers/models/unets/uvit_2d.py b/src/diffusers/models/unets/uvit_2d.py
index 4c99ef88ca..836d41a7f9 100644
--- a/src/diffusers/models/unets/uvit_2d.py
+++ b/src/diffusers/models/unets/uvit_2d.py
@@ -21,6 +21,7 @@ from torch.utils.checkpoint import checkpoint

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
+from ...utils import apply_lora_scale
 from ..attention import AttentionMixin, BasicTransformerBlock, SkipFFTransformerBlock
 from ..attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
@@ -146,6 +147,7 @@ class UVit2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMixin):

         self.gradient_checkpointing = False

+    @apply_lora_scale("cross_attention_kwargs")
     def forward(self, input_ids, encoder_hidden_states, pooled_text_emb, micro_conds, cross_attention_kwargs=None):
         encoder_hidden_states = self.encoder_proj(encoder_hidden_states)
         encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states)
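
Note: these hunks only show `apply_lora_scale` being applied; its definition is added to `diffusers.utils` elsewhere in this change and is not part of this section. Below is a rough, hedged sketch of what a decorator with this behavior could look like, reconstructed from the boilerplate it removes. The `kwargs_name` parameter, the use of `inspect.signature` to locate the kwargs dict, and the `try/finally` placement are illustrative assumptions, not the actual implementation.

# Hedged sketch only -- not the real `diffusers.utils.apply_lora_scale`.
import functools
import inspect

from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers

logger = logging.get_logger(__name__)


def apply_lora_scale(kwargs_name: str):
    """Pop `scale` from the dict argument named `kwargs_name` and scale/unscale LoRA layers around `forward`."""

    def decorator(forward_fn):
        @functools.wraps(forward_fn)
        def wrapper(self, *args, **kwargs):
            # Locate the attention-kwargs dict whether it was passed positionally or by keyword.
            bound = inspect.signature(forward_fn).bind(self, *args, **kwargs)
            bound.apply_defaults()
            attention_kwargs = bound.arguments.get(kwargs_name)

            if attention_kwargs is not None:
                # Copy so the caller's dict is not mutated, then strip `scale` before calling `forward`.
                attention_kwargs = attention_kwargs.copy()
                lora_scale = attention_kwargs.pop("scale", 1.0)
                bound.arguments[kwargs_name] = attention_kwargs
            else:
                lora_scale = 1.0

            if USE_PEFT_BACKEND:
                # weight the LoRA layers by setting `lora_scale` for each PEFT layer
                scale_lora_layers(self, lora_scale)
            elif lora_scale != 1.0:
                logger.warning(
                    f"Passing `scale` via `{kwargs_name}` when not using the PEFT backend is ineffective."
                )

            try:
                return forward_fn(*bound.args, **bound.kwargs)
            finally:
                if USE_PEFT_BACKEND:
                    # remove `lora_scale` from each PEFT layer
                    unscale_lora_layers(self, lora_scale)

        return wrapper

    return decorator

Under a scheme like this, call sites stay unchanged: passing e.g. cross_attention_kwargs={"scale": 0.8} (or the matching attention_kwargs / joint_attention_kwargs) to forward still weights the LoRA layers for that call only, and the try/finally restores them even if forward raises.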