mirror of https://github.com/huggingface/diffusers.git
update
@@ -107,14 +107,10 @@ class AttentionMixin:
        are fused. For cross-attention modules, key and value projection matrices are fused.
        """
        self.original_attn_processors = None

        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        self.original_attn_processors = self.attn_processors

        for module in self.modules():
            if isinstance(module, AttentionModuleMixin):
                module.fuse_projections(fuse=True)
@@ -129,30 +125,58 @@ class AttentionMixin:
        </Tip>

        """
        if self.original_attn_processors is not None:
            self.set_attn_processor(self.original_attn_processors)
        for _, attn_processor in self.attn_processors.items():
            attn_processor.fused_projections = False

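A minimal usage sketch for the fuse/unfuse pair above, assuming a model class that exposes them via AttentionMixin (the model class and checkpoint id are illustrative, not taken from this diff):

import torch
from diffusers import FluxTransformer2DModel

# Illustrative checkpoint; any transformer built on AttentionMixin should behave the same way.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

transformer.fuse_qkv_projections()    # concatenates Q/K/V (or K/V) weights into single projections
# ... run inference with fused projections ...
transformer.unfuse_qkv_projections()  # restores the original attention processors
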
class AttentionModuleMixin:
    """
    A mixin class that provides common methods for attention modules.

    This mixin adds functionality to set different attention processors, handle attention masks, compute attention
    scores, and manage projections.
    """

    # Default processor classes to be overridden by subclasses
    default_processor_cls = None
    _default_processor_cls = None
    _available_processors = []

    fused_projections = False
    is_cross_attention = False

    def _get_compatible_processor(self, backend):
        for processor_cls in self._available_processors:
            if backend in processor_cls.compatible_backends:
                processor = processor_cls()
                return processor

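A sketch of how the backend lookup above is meant to be used; the processor and module classes here are hypothetical, only the compatible_backends / _available_processors attributes come from the code above:

import torch


class MyNPUAttnProcessor:
    # _get_compatible_processor checks membership in this tuple.
    compatible_backends = ("npu",)

    def __call__(self, attn, hidden_states, **kwargs):
        ...  # backend-specific attention math would live here


class MyAttention(AttentionModuleMixin, torch.nn.Module):
    _default_processor_cls = MyNPUAttnProcessor
    _available_processors = [MyNPUAttnProcessor]


attn = MyAttention()
assert isinstance(attn._get_compatible_processor("npu"), MyNPUAttnProcessor)
assert attn._get_compatible_processor("xla") is None  # no processor registered for that backend
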
    def set_processor(self, processor: "AttnProcessor") -> None:
        """
        Set the attention processor to use.

        Args:
            processor (`AttnProcessor`):
                The attention processor to use.
        """
        # if current processor is in `self._modules` and if passed `processor` is not, we need to
        # pop `processor` from `self._modules`
        if (
            hasattr(self, "processor")
            and isinstance(self.processor, torch.nn.Module)
            and not isinstance(processor, torch.nn.Module)
        ):
            logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
            self._modules.pop("processor")

        self.processor = processor

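A sketch of swapping in a custom processor through the setter above and reading it back with get_processor; the processor class and the attn_module object are assumptions for illustration:

class LoggingAttnProcessor:
    """Hypothetical processor that records call shapes before delegating."""

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, **kwargs):
        print("attention over", tuple(hidden_states.shape))
        ...  # delegate to the usual attention computation here


# attn_module: any module that includes AttentionModuleMixin
attn_module.set_processor(LoggingAttnProcessor())
assert isinstance(attn_module.get_processor(), LoggingAttnProcessor)
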
    def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
        """
        Get the attention processor in use.

        Args:
            return_deprecated_lora (`bool`, *optional*, defaults to `False`):
                Set to `True` to return the deprecated LoRA attention processor.

        Returns:
            "AttentionProcessor": The attention processor in use.
        """
        if not return_deprecated_lora:
            return self.processor

    def set_attention_backend(self, backend: str):
        from .attention_dispatch import AttentionBackendName

        available_backends = {x.value for x in AttentionBackendName.__members__.values()}
        if backend not in available_backends:
            raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))

        backend = AttentionBackendName(backend.lower())
        self.processor._attention_backend = backend

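A usage sketch for the setter above; "xformers" is used because it appears elsewhere in this diff as a valid backend name, and attn_module is an assumed AttentionModuleMixin instance:

attn_module.set_attention_backend("xformers")   # routes this module's processor through the xformers backend

try:
    attn_module.set_attention_backend("not-a-backend")
except ValueError as err:
    print(err)  # message lists every valid AttentionBackendName value
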
    def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
        """
@@ -161,14 +185,12 @@ class AttentionModuleMixin:
        Args:
            use_npu_flash_attention (`bool`): Whether to use NPU flash attention or not.
        """
        processor = self.default_processor_cls()

        if use_npu_flash_attention:
            if not is_torch_npu_available():
                raise ImportError("torch_npu is not available")
            processor = self._get_compatible_processor("npu")

        self.set_processor(processor)
        self.set_attention_backend("_native_npu")

    def set_use_xla_flash_attention(
        self,
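A guarded usage sketch for the NPU path above (attn_module is an assumed AttentionModuleMixin instance; is_torch_npu_available is the same helper the method itself relies on):

from diffusers.utils import is_torch_npu_available

if is_torch_npu_available():
    # Picks the NPU-compatible processor and the "_native_npu" dispatch backend.
    attn_module.set_use_npu_flash_attention(True)
else:
    # Falls back to the default processor.
    attn_module.set_use_npu_flash_attention(False)
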
@@ -187,76 +209,11 @@ class AttentionModuleMixin:
            is_flux (`bool`, *optional*, defaults to `False`):
                Whether the model is a Flux model.
        """
        processor = self.default_processor_cls()
        if use_xla_flash_attention:
            if not is_torch_xla_available():
                raise ImportError("torch_xla is not available")
            processor = self._get_compatible_processor("xla")

        self.set_processor(processor)

    @torch.no_grad()
    def fuse_projections(self, fuse=True):
        """
        Fuse the query, key, and value projections into a single projection for efficiency.

        Args:
            fuse (`bool`): Whether to fuse the projections or not.
        """
        # Skip if already in desired state
        if getattr(self, "fused_projections", False) == fuse:
            return

        device = self.to_q.weight.data.device
        dtype = self.to_q.weight.data.dtype

        if not self.is_cross_attention:
            # Fuse self-attention projections
            concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
            in_features = concatenated_weights.shape[1]
            out_features = concatenated_weights.shape[0]

            self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
            self.to_qkv.weight.copy_(concatenated_weights)
            if self.use_bias:
                concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
                self.to_qkv.bias.copy_(concatenated_bias)

        else:
            # Fuse cross-attention key-value projections
            concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
            in_features = concatenated_weights.shape[1]
            out_features = concatenated_weights.shape[0]

            self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
            self.to_kv.weight.copy_(concatenated_weights)
            if self.use_bias:
                concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
                self.to_kv.bias.copy_(concatenated_bias)

        # Handle added projections for models like SD3, Flux, etc.
        if (
            getattr(self, "add_q_proj", None) is not None
            and getattr(self, "add_k_proj", None) is not None
            and getattr(self, "add_v_proj", None) is not None
        ):
            concatenated_weights = torch.cat(
                [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
            )
            in_features = concatenated_weights.shape[1]
            out_features = concatenated_weights.shape[0]

            self.to_added_qkv = nn.Linear(
                in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
            )
            self.to_added_qkv.weight.copy_(concatenated_weights)
            if self.added_proj_bias:
                concatenated_bias = torch.cat(
                    [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
                )
                self.to_added_qkv.bias.copy_(concatenated_bias)

        self.fused_projections = fuse
        self.set_attention_backend("_native_xla")

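The fusion above is only a weight concatenation, so the fused projection must reproduce the separate ones exactly; a small standalone check of that equivalence in plain torch (shapes are illustrative):

import torch
from torch import nn

dim, batch, seq = 64, 2, 8
to_q, to_k, to_v = (nn.Linear(dim, dim, bias=True) for _ in range(3))

# Build the fused layer the same way fuse_projections does: concatenate weights along dim 0.
to_qkv = nn.Linear(dim, 3 * dim, bias=True)
with torch.no_grad():
    to_qkv.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight]))
    to_qkv.bias.copy_(torch.cat([to_q.bias, to_k.bias, to_v.bias]))

x = torch.randn(batch, seq, dim)
q, k, v = to_qkv(x).chunk(3, dim=-1)
assert torch.allclose(q, to_q(x), atol=1e-5)
assert torch.allclose(k, to_k(x), atol=1e-5)
assert torch.allclose(v, to_v(x), atol=1e-5)
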
    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
@@ -295,13 +252,87 @@ class AttentionModuleMixin:
            except Exception as e:
                raise e

            processor = self._get_compatible_processor("xformers")
        else:
            # Set default processor
            processor = self.default_processor_cls()
        self.set_attention_backend("xformers")

        if processor is not None:
            self.set_processor(processor)

    @torch.no_grad()
    def fuse_projections(self):
        """
        Fuse the query, key, and value projections into a single projection for efficiency.
        """
        # Skip if already fused
        if getattr(self, "fused_projections", False):
            return

        device = self.to_q.weight.data.device
        dtype = self.to_q.weight.data.dtype

        if hasattr(self, "is_cross_attention") and self.is_cross_attention:
            # Fuse cross-attention key-value projections
            concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
            in_features = concatenated_weights.shape[1]
            out_features = concatenated_weights.shape[0]

            self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
            self.to_kv.weight.copy_(concatenated_weights)
            if hasattr(self, "use_bias") and self.use_bias:
                concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
                self.to_kv.bias.copy_(concatenated_bias)
        else:
            # Fuse self-attention projections
            concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
            in_features = concatenated_weights.shape[1]
            out_features = concatenated_weights.shape[0]

            self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
            self.to_qkv.weight.copy_(concatenated_weights)
            if hasattr(self, "use_bias") and self.use_bias:
                concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
                self.to_qkv.bias.copy_(concatenated_bias)

        # Handle added projections for models like SD3, Flux, etc.
        if (
            getattr(self, "add_q_proj", None) is not None
            and getattr(self, "add_k_proj", None) is not None
            and getattr(self, "add_v_proj", None) is not None
        ):
            concatenated_weights = torch.cat(
                [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
            )
            in_features = concatenated_weights.shape[1]
            out_features = concatenated_weights.shape[0]

            self.to_added_qkv = nn.Linear(
                in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
            )
            self.to_added_qkv.weight.copy_(concatenated_weights)
            if self.added_proj_bias:
                concatenated_bias = torch.cat(
                    [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
                )
                self.to_added_qkv.bias.copy_(concatenated_bias)

        self.fused_projections = True

    @torch.no_grad()
    def unfuse_projections(self):
        """
        Unfuse the query, key, and value projections back to separate projections.
        """
        # Skip if not fused
        if not getattr(self, "fused_projections", False):
            return

        # Remove fused projection layers
        if hasattr(self, "to_qkv"):
            delattr(self, "to_qkv")

        if hasattr(self, "to_kv"):
            delattr(self, "to_kv")

        if hasattr(self, "to_added_qkv"):
            delattr(self, "to_added_qkv")

        self.fused_projections = False

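A roundtrip sketch for the fuse/unfuse pair above (attn_module is an assumed self-attention module built on AttentionModuleMixin, with to_q/to_k/to_v layers):

attn_module.fuse_projections()
assert attn_module.fused_projections
assert hasattr(attn_module, "to_qkv")      # to_kv instead for cross-attention modules

attn_module.unfuse_projections()
assert not attn_module.fused_projections
assert not hasattr(attn_module, "to_qkv")  # the fused layer is deleted, the original projections remain
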
    def set_attention_slice(self, slice_size: int) -> None:
        """
@@ -326,40 +357,6 @@ class AttentionModuleMixin:
        self.set_processor(processor)

    def set_processor(self, processor: "AttnProcessor") -> None:
        """
        Set the attention processor to use.

        Args:
            processor (`AttnProcessor`):
                The attention processor to use.
        """
        # if current processor is in `self._modules` and if passed `processor` is not, we need to
        # pop `processor` from `self._modules`
        if (
            hasattr(self, "processor")
            and isinstance(self.processor, torch.nn.Module)
            and not isinstance(processor, torch.nn.Module)
        ):
            logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
            self._modules.pop("processor")

        self.processor = processor

    def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
        """
        Get the attention processor in use.

        Args:
            return_deprecated_lora (`bool`, *optional*, defaults to `False`):
                Set to `True` to return the deprecated LoRA attention processor.

        Returns:
            "AttentionProcessor": The attention processor in use.
        """
        if not return_deprecated_lora:
            return self.processor

    def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`.

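The reshape described in that docstring, written out as a standalone sketch (head count and sizes are illustrative; this mirrors the documented shape contract rather than the exact implementation):

import torch

heads, batch, seq, head_dim = 4, 2, 16, 32
tensor = torch.randn(batch * heads, seq, head_dim)  # [batch_size, seq_len, dim] with batch_size = batch * heads

out = (
    tensor.reshape(batch, heads, seq, head_dim)
    .permute(0, 2, 1, 3)
    .reshape(batch, seq, heads * head_dim)
)
assert out.shape == (batch, seq, heads * head_dim)  # [batch_size // heads, seq_len, dim * heads]
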
@@ -366,6 +366,8 @@ class FluxTransformerBlock(nn.Module):
            cross_attention_dim=None,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            qk_norm=qk_norm,
            eps=eps,
            dropout=0.0,
            bias=True,
            added_kv_proj_dim=dim,
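For context on the arguments above, the usual head arithmetic, with illustrative numbers (not values read from this diff):

# The block's inner/hidden dimension is heads * dim_head.
num_attention_heads = 24
attention_head_dim = 128
dim = num_attention_heads * attention_head_dim  # 3072 in this illustration

assert dim % num_attention_heads == 0
# added_kv_proj_dim=dim requests the extra add_q_proj/add_k_proj/add_v_proj layers for the
# second (text) stream, which is the case handled by the to_added_qkv fusion earlier in this diff.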