Commit 3cb66e8786 (parent 2891f14127)
Author: DN6
Date: 2025-06-10 09:27:43 +05:30
2 changed files with 130 additions and 131 deletions


@@ -107,14 +107,10 @@ class AttentionMixin:
are fused. For cross-attention modules, key and value projection matrices are fused.
"""
self.original_attn_processors = None
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
self.original_attn_processors = self.attn_processors
for module in self.modules():
if isinstance(module, AttentionModuleMixin):
module.fuse_projections(fuse=True)
@@ -129,30 +125,58 @@ class AttentionMixin:
</Tip>
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
for _, attn_processor in self.attn_processors.items():
attn_processor.fused_projections = False
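For orientation, a minimal usage sketch of the two model-level entry points this hunk touches; the checkpoint id is illustrative and `FluxTransformer2DModel` is assumed to mix in `AttentionMixin`:

import torch
from diffusers import FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)
transformer.fuse_qkv_projections()    # each AttentionModuleMixin submodule fuses its projection weights
# ... run inference with the fused projections ...
transformer.unfuse_qkv_projections()  # restores the saved processors and clears the fused flags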
class AttentionModuleMixin:
"""
A mixin class that provides common methods for attention modules.
This mixin adds functionality to set different attention processors, handle attention masks, compute attention
scores, and manage projections.
"""
# Default processor classes to be overridden by subclasses
default_processor_cls = None
_default_processor_cls = None
_available_processors = []
fused_projections = False
is_cross_attention = False
def _get_compatible_processor(self, backend):
for processor_cls in self._available_processors:
if backend in processor_cls.compatible_backends:
processor = processor_cls()
return processor
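The lookup above relies on each processor class advertising a `compatible_backends` list; a hypothetical entry might look like the following (the class name and backend string are illustrative, not part of this diff):

class HypotheticalNPUAttnProcessor:
    # advertises which dispatch backends this processor can serve
    compatible_backends = ["npu"]

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, **kwargs):
        ...  # backend-specific attention math would go here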
def set_processor(self, processor: "AttnProcessor") -> None:
"""
Set the attention processor to use.
Args:
processor (`AttnProcessor`):
The attention processor to use.
"""
# if current processor is in `self._modules` and if passed `processor` is not, we need to
# pop `processor` from `self._modules`
if (
hasattr(self, "processor")
and isinstance(self.processor, torch.nn.Module)
and not isinstance(processor, torch.nn.Module)
):
logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
self._modules.pop("processor")
self.processor = processor
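As a sketch of how a custom processor is swapped in at the module level (the `CustomAttnProcessor` stub and the attribute path used to reach the module are hypothetical):

class CustomAttnProcessor:
    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, **kwargs):
        ...  # custom attention math; this stub only shows the plumbing

attn_module = transformer.transformer_blocks[0].attn  # hypothetical lookup path
attn_module.set_processor(CustomAttnProcessor())
print(attn_module.get_processor())                    # -> the CustomAttnProcessor instance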
def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
"""
Get the attention processor in use.
Args:
return_deprecated_lora (`bool`, *optional*, defaults to `False`):
Set to `True` to return the deprecated LoRA attention processor.
Returns:
"AttentionProcessor": The attention processor in use.
"""
if not return_deprecated_lora:
return self.processor
def set_attention_backend(self, backend: str):
from .attention_dispatch import AttentionBackendName
available_backends = {x.value for x in AttentionBackendName.__members__.values()}
if backend not in available_backends:
raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
backend = AttentionBackendName(backend.lower())
self.processor._attention_backend = backend
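Continuing the sketch above, the backend can be switched per module; only names that appear elsewhere in this diff ("xformers", "_native_npu", "_native_xla") are used here, and an unknown name raises:

for module in transformer.modules():
    if isinstance(module, AttentionModuleMixin):
        module.set_attention_backend("xformers")

try:
    attn_module.set_attention_backend("not-a-backend")
except ValueError as err:
    print(err)  # lists the valid AttentionBackendName values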
def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
"""
@@ -161,14 +185,12 @@ class AttentionModuleMixin:
Args:
use_npu_flash_attention (`bool`): Whether to use NPU flash attention or not.
"""
processor = self.default_processor_cls()
if use_npu_flash_attention:
if not is_torch_npu_available():
raise ImportError("torch_npu is not available")
processor = self._get_compatible_processor("npu")
self.set_processor(processor)
self.set_attention_backend("_native_npu")
def set_use_xla_flash_attention(
self,
@@ -187,76 +209,11 @@ class AttentionModuleMixin:
is_flux (`bool`, *optional*, defaults to `False`):
Whether the model is a Flux model.
"""
processor = self.default_processor_cls()
if use_xla_flash_attention:
if not is_torch_xla_available():
raise ImportError("torch_xla is not available")
processor = self._get_compatible_processor("xla")
self.set_processor(processor)
@torch.no_grad()
def fuse_projections(self, fuse=True):
"""
Fuse the query, key, and value projections into a single projection for efficiency.
Args:
fuse (`bool`): Whether to fuse the projections or not.
"""
# Skip if already in desired state
if getattr(self, "fused_projections", False) == fuse:
return
device = self.to_q.weight.data.device
dtype = self.to_q.weight.data.dtype
if not self.is_cross_attention:
# Fuse self-attention projections
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
self.to_qkv.weight.copy_(concatenated_weights)
if self.use_bias:
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
self.to_qkv.bias.copy_(concatenated_bias)
else:
# Fuse cross-attention key-value projections
concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
self.to_kv.weight.copy_(concatenated_weights)
if self.use_bias:
concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
self.to_kv.bias.copy_(concatenated_bias)
# Handle added projections for models like SD3, Flux, etc.
if (
getattr(self, "add_q_proj", None) is not None
and getattr(self, "add_k_proj", None) is not None
and getattr(self, "add_v_proj", None) is not None
):
concatenated_weights = torch.cat(
[self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
)
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
self.to_added_qkv = nn.Linear(
in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
)
self.to_added_qkv.weight.copy_(concatenated_weights)
if self.added_proj_bias:
concatenated_bias = torch.cat(
[self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
)
self.to_added_qkv.bias.copy_(concatenated_bias)
self.fused_projections = fuse
self.set_attention_backend("_native_xla")
def set_use_memory_efficient_attention_xformers(
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
@@ -295,13 +252,87 @@ class AttentionModuleMixin:
except Exception as e:
raise e
processor = self._get_compatible_processor("xformers")
else:
# Set default processor
processor = self.default_processor_cls()
self.set_attention_backend("xformers")
if processor is not None:
self.set_processor(processor)
@torch.no_grad()
def fuse_projections(self):
"""
Fuse the query, key, and value projections into a single projection for efficiency.
"""
# Skip if already fused
if getattr(self, "fused_projections", False):
return
device = self.to_q.weight.data.device
dtype = self.to_q.weight.data.dtype
if hasattr(self, "is_cross_attention") and self.is_cross_attention:
# Fuse cross-attention key-value projections
concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
self.to_kv.weight.copy_(concatenated_weights)
if hasattr(self, "use_bias") and self.use_bias:
concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
self.to_kv.bias.copy_(concatenated_bias)
else:
# Fuse self-attention projections
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
self.to_qkv.weight.copy_(concatenated_weights)
if hasattr(self, "use_bias") and self.use_bias:
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
self.to_qkv.bias.copy_(concatenated_bias)
# Handle added projections for models like SD3, Flux, etc.
if (
getattr(self, "add_q_proj", None) is not None
and getattr(self, "add_k_proj", None) is not None
and getattr(self, "add_v_proj", None) is not None
):
concatenated_weights = torch.cat(
[self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
)
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
self.to_added_qkv = nn.Linear(
in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
)
self.to_added_qkv.weight.copy_(concatenated_weights)
if self.added_proj_bias:
concatenated_bias = torch.cat(
[self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
)
self.to_added_qkv.bias.copy_(concatenated_bias)
self.fused_projections = True
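After fusion the original `to_q`/`to_k`/`to_v` (and `add_*_proj`) modules are left in place; a processor that takes the fused path reads `to_qkv`/`to_kv` and splits the output, roughly like this (a sketch of the consuming side, not defined in this diff):

# inside a processor __call__ for a fused self-attention module
qkv = attn.to_qkv(hidden_states)                 # [batch, seq_len, 3 * inner_dim]
query, key, value = torch.chunk(qkv, 3, dim=-1)  # undoes the row-wise weight concatenation

# cross-attention counterpart: the query stays separate, key/value come fused
# kv = attn.to_kv(encoder_hidden_states)
# key, value = torch.chunk(kv, 2, dim=-1)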
@torch.no_grad()
def unfuse_projections(self):
"""
Unfuse the query, key, and value projections back to separate projections.
"""
# Skip if not fused
if not getattr(self, "fused_projections", False):
return
# Remove fused projection layers
if hasattr(self, "to_qkv"):
delattr(self, "to_qkv")
if hasattr(self, "to_kv"):
delattr(self, "to_kv")
if hasattr(self, "to_added_qkv"):
delattr(self, "to_added_qkv")
self.fused_projections = False
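A quick module-level round trip, assuming `attn_module` is any `AttentionModuleMixin` instance:

attn_module.fuse_projections()
assert attn_module.fused_projections
assert hasattr(attn_module, "to_qkv") or hasattr(attn_module, "to_kv")

attn_module.unfuse_projections()
assert not attn_module.fused_projections
assert not hasattr(attn_module, "to_qkv") and not hasattr(attn_module, "to_kv")
# to_q / to_k / to_v were never removed, so no weight reload is needed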
def set_attention_slice(self, slice_size: int) -> None:
"""
@@ -326,40 +357,6 @@ class AttentionModuleMixin:
self.set_processor(processor)
def set_processor(self, processor: "AttnProcessor") -> None:
"""
Set the attention processor to use.
Args:
processor (`AttnProcessor`):
The attention processor to use.
"""
# if current processor is in `self._modules` and if passed `processor` is not, we need to
# pop `processor` from `self._modules`
if (
hasattr(self, "processor")
and isinstance(self.processor, torch.nn.Module)
and not isinstance(processor, torch.nn.Module)
):
logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
self._modules.pop("processor")
self.processor = processor
def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
"""
Get the attention processor in use.
Args:
return_deprecated_lora (`bool`, *optional*, defaults to `False`):
Set to `True` to return the deprecated LoRA attention processor.
Returns:
"AttentionProcessor": The attention processor in use.
"""
if not return_deprecated_lora:
return self.processor
def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
"""
Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`.


@@ -366,6 +366,8 @@ class FluxTransformerBlock(nn.Module):
cross_attention_dim=None,
dim_head=attention_head_dim,
heads=num_attention_heads,
qk_norm=qk_norm,
eps=eps,
dropout=0.0,
bias=True,
added_kv_proj_dim=dim,
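For reference, a hedged construction sketch showing where the newly forwarded `qk_norm` and `eps` land; the import path, signature, and dimensions below are assumptions, not taken from this diff:

import torch
from diffusers.models.transformers.transformer_flux import FluxTransformerBlock

block = FluxTransformerBlock(
    dim=3072,                 # Flux-dev sized hidden dim (illustrative)
    num_attention_heads=24,
    attention_head_dim=128,
    qk_norm="rms_norm",       # now reaches the attention module's q/k normalization
    eps=1e-6,
)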