From 3cb66e87865a4ac725f7fd77d0960f1f2ad326f7 Mon Sep 17 00:00:00 2001 From: DN6 Date: Tue, 10 Jun 2025 09:27:43 +0530 Subject: [PATCH] update --- src/diffusers/models/attention.py | 259 +++++++++--------- .../models/transformers/transformer_flux.py | 2 + 2 files changed, 130 insertions(+), 131 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 33c267a29e..02d17e8e95 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -107,14 +107,10 @@ class AttentionMixin: are fused. For cross-attention modules, key and value projection matrices are fused. """ - self.original_attn_processors = None - for _, attn_processor in self.attn_processors.items(): if "Added" in str(attn_processor.__class__.__name__): raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") - self.original_attn_processors = self.attn_processors - for module in self.modules(): if isinstance(module, AttentionModuleMixin): module.fuse_projections(fuse=True) @@ -129,30 +125,58 @@ class AttentionMixin: """ - if self.original_attn_processors is not None: - self.set_attn_processor(self.original_attn_processors) + for _, attn_processor in self.attn_processors.items(): + attn_processor.fused_projections = False class AttentionModuleMixin: - """ - A mixin class that provides common methods for attention modules. - - This mixin adds functionality to set different attention processors, handle attention masks, compute attention - scores, and manage projections. - """ - - # Default processor classes to be overridden by subclasses - default_processor_cls = None + _default_processor_cls = None _available_processors = [] - fused_projections = False - is_cross_attention = False - def _get_compatible_processor(self, backend): - for processor_cls in self._available_processors: - if backend in processor_cls.compatible_backends: - processor = processor_cls() - return processor + def set_processor(self, processor: "AttnProcessor") -> None: + """ + Set the attention processor to use. + + Args: + processor (`AttnProcessor`): + The attention processor to use. + """ + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + + def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor": + """ + Get the attention processor in use. + + Args: + return_deprecated_lora (`bool`, *optional*, defaults to `False`): + Set to `True` to return the deprecated LoRA attention processor. + + Returns: + "AttentionProcessor": The attention processor in use. 
+ """ + if not return_deprecated_lora: + return self.processor + + def set_attention_backend(self, backend: str): + from .attention_dispatch import AttentionBackendName + + available_backends = {x.value for x in AttentionBackendName.__members__.values()} + if backend not in available_backends: + raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends)) + + backend = AttentionBackendName(backend.lower()) + self.processor._attention_backend = backend def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None: """ @@ -161,14 +185,12 @@ class AttentionModuleMixin: Args: use_npu_flash_attention (`bool`): Whether to use NPU flash attention or not. """ - processor = self.default_processor_cls() if use_npu_flash_attention: if not is_torch_npu_available(): raise ImportError("torch_npu is not available") - processor = self._get_compatible_processor("npu") - self.set_processor(processor) + self.set_attention_backend("_native_npu") def set_use_xla_flash_attention( self, @@ -187,76 +209,11 @@ class AttentionModuleMixin: is_flux (`bool`, *optional*, defaults to `False`): Whether the model is a Flux model. """ - processor = self.default_processor_cls() if use_xla_flash_attention: if not is_torch_xla_available(): raise ImportError("torch_xla is not available") - processor = self._get_compatible_processor("xla") - self.set_processor(processor) - - @torch.no_grad() - def fuse_projections(self, fuse=True): - """ - Fuse the query, key, and value projections into a single projection for efficiency. - - Args: - fuse (`bool`): Whether to fuse the projections or not. - """ - # Skip if already in desired state - if getattr(self, "fused_projections", False) == fuse: - return - - device = self.to_q.weight.data.device - dtype = self.to_q.weight.data.dtype - - if not self.is_cross_attention: - # Fuse self-attention projections - concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]) - in_features = concatenated_weights.shape[1] - out_features = concatenated_weights.shape[0] - - self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype) - self.to_qkv.weight.copy_(concatenated_weights) - if self.use_bias: - concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data]) - self.to_qkv.bias.copy_(concatenated_bias) - - else: - # Fuse cross-attention key-value projections - concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data]) - in_features = concatenated_weights.shape[1] - out_features = concatenated_weights.shape[0] - - self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype) - self.to_kv.weight.copy_(concatenated_weights) - if self.use_bias: - concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data]) - self.to_kv.bias.copy_(concatenated_bias) - - # Handle added projections for models like SD3, Flux, etc. 
- if ( - getattr(self, "add_q_proj", None) is not None - and getattr(self, "add_k_proj", None) is not None - and getattr(self, "add_v_proj", None) is not None - ): - concatenated_weights = torch.cat( - [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data] - ) - in_features = concatenated_weights.shape[1] - out_features = concatenated_weights.shape[0] - - self.to_added_qkv = nn.Linear( - in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype - ) - self.to_added_qkv.weight.copy_(concatenated_weights) - if self.added_proj_bias: - concatenated_bias = torch.cat( - [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data] - ) - self.to_added_qkv.bias.copy_(concatenated_bias) - - self.fused_projections = fuse + self.set_attention_backend("_native_xla") def set_use_memory_efficient_attention_xformers( self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None @@ -295,13 +252,87 @@ class AttentionModuleMixin: except Exception as e: raise e - processor = self._get_compatible_processor("xformers") - else: - # Set default processor - processor = self.default_processor_cls() + self.set_attention_backend("xformers") - if processor is not None: - self.set_processor(processor) + @torch.no_grad() + def fuse_projections(self): + """ + Fuse the query, key, and value projections into a single projection for efficiency. + """ + # Skip if already fused + if getattr(self, "fused_projections", False): + return + + device = self.to_q.weight.data.device + dtype = self.to_q.weight.data.dtype + + if hasattr(self, "is_cross_attention") and self.is_cross_attention: + # Fuse cross-attention key-value projections + concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data]) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype) + self.to_kv.weight.copy_(concatenated_weights) + if hasattr(self, "use_bias") and self.use_bias: + concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data]) + self.to_kv.bias.copy_(concatenated_bias) + else: + # Fuse self-attention projections + concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype) + self.to_qkv.weight.copy_(concatenated_weights) + if hasattr(self, "use_bias") and self.use_bias: + concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data]) + self.to_qkv.bias.copy_(concatenated_bias) + + # Handle added projections for models like SD3, Flux, etc. 
+ if ( + getattr(self, "add_q_proj", None) is not None + and getattr(self, "add_k_proj", None) is not None + and getattr(self, "add_v_proj", None) is not None + ): + concatenated_weights = torch.cat( + [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data] + ) + in_features = concatenated_weights.shape[1] + out_features = concatenated_weights.shape[0] + + self.to_added_qkv = nn.Linear( + in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype + ) + self.to_added_qkv.weight.copy_(concatenated_weights) + if self.added_proj_bias: + concatenated_bias = torch.cat( + [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data] + ) + self.to_added_qkv.bias.copy_(concatenated_bias) + + self.fused_projections = True + + @torch.no_grad() + def unfuse_projections(self): + """ + Unfuse the query, key, and value projections back to separate projections. + """ + # Skip if not fused + if not getattr(self, "fused_projections", False): + return + + # Remove fused projection layers + if hasattr(self, "to_qkv"): + delattr(self, "to_qkv") + + if hasattr(self, "to_kv"): + delattr(self, "to_kv") + + if hasattr(self, "to_added_qkv"): + delattr(self, "to_added_qkv") + + self.fused_projections = False def set_attention_slice(self, slice_size: int) -> None: """ @@ -326,40 +357,6 @@ class AttentionModuleMixin: self.set_processor(processor) - def set_processor(self, processor: "AttnProcessor") -> None: - """ - Set the attention processor to use. - - Args: - processor (`AttnProcessor`): - The attention processor to use. - """ - # if current processor is in `self._modules` and if passed `processor` is not, we need to - # pop `processor` from `self._modules` - if ( - hasattr(self, "processor") - and isinstance(self.processor, torch.nn.Module) - and not isinstance(processor, torch.nn.Module) - ): - logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") - self._modules.pop("processor") - - self.processor = processor - - def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor": - """ - Get the attention processor in use. - - Args: - return_deprecated_lora (`bool`, *optional*, defaults to `False`): - Set to `True` to return the deprecated LoRA attention processor. - - Returns: - "AttentionProcessor": The attention processor in use. - """ - if not return_deprecated_lora: - return self.processor - def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor: """ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 5a0309b260..919a158880 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -366,6 +366,8 @@ class FluxTransformerBlock(nn.Module): cross_attention_dim=None, dim_head=attention_head_dim, heads=num_attention_heads, + qk_norm=qk_norm, + eps=eps, dropout=0.0, bias=True, added_kv_proj_dim=dim,
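
Backend selection sketch (not part of the patch): the NPU, XLA, and xformers toggles above all collapse into `set_attention_backend`, which validates the requested name against `AttentionBackendName` and stores the selected backend on the processor. The snippet below mirrors that validate-and-store flow in isolation; the member list of `AttentionBackendName`, plus `DemoProcessor` and `DemoAttention`, are illustrative stand-ins and not the real classes from `attention_dispatch.py`.

# Minimal, hypothetical mirror of the backend-selection flow added in this patch.
# `AttentionBackendName`, `DemoProcessor`, and `DemoAttention` are stand-ins only.
from enum import Enum


class AttentionBackendName(str, Enum):
    NATIVE = "native"
    NATIVE_NPU = "_native_npu"
    NATIVE_XLA = "_native_xla"
    XFORMERS = "xformers"


class DemoProcessor:
    def __init__(self):
        # The selected backend is stored on the processor and read at dispatch time.
        self._attention_backend = AttentionBackendName.NATIVE


class DemoAttention:
    def __init__(self):
        self.processor = DemoProcessor()

    def set_attention_backend(self, backend: str) -> None:
        # Validate the name against the known backends, then store the enum member.
        available_backends = {x.value for x in AttentionBackendName.__members__.values()}
        if backend not in available_backends:
            raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
        self.processor._attention_backend = AttentionBackendName(backend.lower())

    def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
        # The old per-backend processor swap becomes a thin wrapper over the selector.
        if use_npu_flash_attention:
            self.set_attention_backend("_native_npu")


attn = DemoAttention()
attn.set_attention_backend("xformers")
assert attn.processor._attention_backend is AttentionBackendName.XFORMERS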
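
QKV fusion sketch (not part of the patch): `fuse_projections` builds `to_qkv` by concatenating the `to_q`/`to_k`/`to_v` weight matrices along the output dimension (and the biases end to end), so a single matmul replaces three; `unfuse_projections` only deletes the fused layers and clears the flag, since the original projections are never removed. The toy check below, with made-up dimensions, verifies that the concatenated layer reproduces the separate projections.

# Toy check of the weight math behind `fuse_projections`; dimensions are made up.
import torch
import torch.nn as nn

torch.manual_seed(0)
dim, seq, batch = 32, 4, 2

to_q = nn.Linear(dim, dim, bias=True)
to_k = nn.Linear(dim, dim, bias=True)
to_v = nn.Linear(dim, dim, bias=True)

# Fuse: stack the weight matrices row-wise and the biases end to end, as in the patch.
with torch.no_grad():
    concatenated_weights = torch.cat([to_q.weight.data, to_k.weight.data, to_v.weight.data])
    concatenated_bias = torch.cat([to_q.bias.data, to_k.bias.data, to_v.bias.data])
    to_qkv = nn.Linear(concatenated_weights.shape[1], concatenated_weights.shape[0], bias=True)
    to_qkv.weight.copy_(concatenated_weights)
    to_qkv.bias.copy_(concatenated_bias)

hidden_states = torch.randn(batch, seq, dim)
q, k, v = to_qkv(hidden_states).chunk(3, dim=-1)  # one matmul instead of three
assert torch.allclose(q, to_q(hidden_states), atol=1e-5)
assert torch.allclose(k, to_k(hidden_states), atol=1e-5)
assert torch.allclose(v, to_v(hidden_states), atol=1e-5)

# "Unfusing" in the patch just drops the fused layer and clears the flag;
# the separate to_q/to_k/to_v modules were never removed.
del to_qkv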
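
Model-level usage sketch (assumptions: a diffusers build that includes this patch, and `FluxTransformer2DModel` exposing the `AttentionMixin` API; the checkpoint id and dtype are examples only): `fuse_qkv_projections` now simply walks `self.modules()` and fuses every attention module, while `unfuse_qkv_projections` resets the `fused_projections` flag on each processor rather than restoring cached processors. The `transformer_flux.py` hunk additionally forwards the block's `qk_norm` and `eps` arguments into the attention constructor in `FluxTransformerBlock`.

# Hypothetical usage of the fuse/unfuse API after this patch; loading the example
# checkpoint requires the usual download and authentication.
import torch
from diffusers import FluxTransformer2DModel

model = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

model.fuse_qkv_projections()    # concatenates to_q/to_k/to_v (and add_*_proj) into fused linears
# ... run inference with fused projections ...
model.unfuse_qkv_projections()  # drops the fused layers and clears the fused flags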