Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-29 07:22:12 +03:00

Commit: update
@@ -163,6 +163,7 @@ else:
         [
             "AllegroTransformer3DModel",
             "AsymmetricAutoencoderKL",
+            "AttentionBackendName",
             "AuraFlowTransformer2DModel",
             "AutoencoderDC",
             "AutoencoderKL",
@@ -237,6 +238,7 @@ else:
             "VQModel",
             "WanTransformer3DModel",
             "WanVACETransformer3DModel",
+            "attention_backend",
         ]
     )
     _import_structure["modular_pipelines"].extend(
@@ -809,6 +811,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         from .models import (
             AllegroTransformer3DModel,
             AsymmetricAutoencoderKL,
+            AttentionBackendName,
             AuraFlowTransformer2DModel,
             AutoencoderDC,
             AutoencoderKL,
@@ -882,6 +885,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
             VQModel,
             WanTransformer3DModel,
             WanVACETransformer3DModel,
+            attention_backend,
         )
         from .modular_pipelines import (
             ComponentsManager,
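These export hunks make the dispatcher part of the public API. A minimal usage sketch of the new top-level names (treating `attention_backend` as a context manager is an assumption; this excerpt only shows the re-exports):

```python
from diffusers import AttentionBackendName, attention_backend

# Enumerate whichever backend names this diffusers build knows about.
print([backend.value for backend in AttentionBackendName])

# Assumed usage: temporarily route attention through a chosen backend.
with attention_backend("native"):
    pass  # run a pipeline or model forward pass here
```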

@@ -26,6 +26,7 @@ _import_structure = {}
 
 if is_torch_available():
     _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
+    _import_structure["attention_dispatch"] = ["AttentionBackendName", "attention_backend"]
     _import_structure["auto_model"] = ["AutoModel"]
     _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
     _import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"]
@@ -111,6 +112,7 @@ if is_flax_available():
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     if is_torch_available():
         from .adapter import MultiAdapter, T2IAdapter
+        from .attention_dispatch import AttentionBackendName, attention_backend
         from .auto_model import AutoModel
         from .autoencoders import (
             AsymmetricAutoencoderKL,

src/diffusers/models/attention_dispatch.py: new file, 1155 lines (file diff suppressed because it is too large).
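Because the 1155-line module itself is suppressed above, here is a minimal sketch of the dispatch pattern its exported names (`AttentionBackendName`, `attention_backend`, `dispatch_attention_fn`) suggest. Everything beyond those names, including enum members other than `"native"` and the resolution order, is an assumption, not the real implementation:

```python
import contextlib
import os
from enum import Enum
from typing import Optional

import torch.nn.functional as F


class AttentionBackendName(str, Enum):
    NATIVE = "native"  # the only value confirmed by this diff (DIFFUSERS_ATTN_BACKEND default)


_active_backend: Optional[AttentionBackendName] = None


@contextlib.contextmanager
def attention_backend(backend: str = "native"):
    # Assumed context-manager form: route dispatched attention through `backend` inside the block.
    global _active_backend
    previous, _active_backend = _active_backend, AttentionBackendName(backend)
    try:
        yield
    finally:
        _active_backend = previous


def dispatch_attention_fn(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, backend=None):
    # Assumed resolution order: explicit `backend` (set by a processor), then the context manager,
    # then the DIFFUSERS_ATTN_BACKEND environment default.
    backend = backend or _active_backend or AttentionBackendName(os.getenv("DIFFUSERS_ATTN_BACKEND", "native"))
    if backend == AttentionBackendName.NATIVE:
        # Per-backend layout handling and the optional DIFFUSERS_ATTN_CHECKS validation are omitted here.
        return F.scaled_dot_product_attention(
            query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
        )
    raise NotImplementedError(f"backend {backend!r} is not covered by this sketch")
```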
@@ -606,6 +606,56 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             offload_to_disk_path=offload_to_disk_path,
         )
 
+    def set_attention_backend(self, backend: str) -> None:
+        """
+        Set the attention backend for the model.
+
+        Args:
+            backend (`str`):
+                The name of the backend to set. Must be one of the available backends defined in
+                `AttentionBackendName`. Available backends can be found in
+                `diffusers.attention_dispatch.AttentionBackendName`. Defaults to torch native scaled dot product
+                attention as backend.
+        """
+        from .attention import AttentionModuleMixin
+        from .attention_dispatch import AttentionBackendName
+
+        # TODO: the following will not be required when everything is refactored to AttentionModuleMixin
+        from .attention_processor import Attention, MochiAttention
+
+        backend = backend.lower()
+        available_backends = {x.value for x in AttentionBackendName.__members__.values()}
+        if backend not in available_backends:
+            raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
+
+        backend = AttentionBackendName(backend)
+        attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
+
+        for module in self.modules():
+            if not isinstance(module, attention_classes):
+                continue
+            processor = module.processor
+            if processor is None or not hasattr(processor, "_attention_backend"):
+                continue
+            processor._attention_backend = backend
+
+    def reset_attention_backend(self) -> None:
+        """
+        Resets the attention backend for the model. Following calls to `forward` will use the environment default or
+        the torch native scaled dot product attention.
+        """
+        from .attention import AttentionModuleMixin
+        from .attention_processor import Attention, MochiAttention
+
+        attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
+        for module in self.modules():
+            if not isinstance(module, attention_classes):
+                continue
+            processor = module.processor
+            if processor is None or not hasattr(processor, "_attention_backend"):
+                continue
+            processor._attention_backend = None
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],

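A usage sketch of the two new `ModelMixin` methods. `"native"` is the only backend string confirmed by this excerpt (it is the `DIFFUSERS_ATTN_BACKEND` default); the checkpoint id is illustrative, and any `ModelMixin` subclass behaves the same way:

```python
from diffusers import FluxTransformer2DModel

model = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer"
)

model.set_attention_backend("native")  # walks self.modules() and tags every processor's _attention_backend
# ... run inference ...
model.reset_attention_backend()  # clears the override; forward falls back to the environment default / torch SDPA
```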
@@ -26,6 +26,7 @@ from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, un
 from ...utils.import_utils import is_torch_npu_available
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
 from ..cache_utils import CacheMixin
 from ..embeddings import (
     CombinedTimestepGuidanceTextProjEmbeddings,
@@ -42,6 +43,8 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
 class FluxAttnProcessor:
+    _attention_backend = None
+
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
@@ -51,31 +54,25 @@ class FluxAttnProcessor:
         query = attn.to_q(hidden_states)
         key = attn.to_k(hidden_states)
         value = attn.to_v(hidden_states)
 
-        encoder_projections = None
-        if encoder_hidden_states is not None and hasattr(attn, "add_q_proj"):
+        encoder_query = encoder_key = encoder_value = None
+        if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
             encoder_query = attn.add_q_proj(encoder_hidden_states)
             encoder_key = attn.add_k_proj(encoder_hidden_states)
             encoder_value = attn.add_v_proj(encoder_hidden_states)
-            encoder_projections = (encoder_query, encoder_key, encoder_value)
 
-        return query, key, value, encoder_projections
+        return query, key, value, encoder_query, encoder_key, encoder_value
 
     def _get_fused_projections(self, attn, hidden_states, encoder_hidden_states=None):
-        qkv = attn.to_qkv(hidden_states)
-        split_size = qkv.shape[-1] // 3
-        query, key, value = torch.split(qkv, split_size, dim=-1)
+        query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
 
-        encoder_projections = None
+        encoder_query = encoder_key = encoder_value = (None,)
         if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
-            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
-            split_size = encoder_qkv.shape[-1] // 3
-            encoder_query, encoder_key, encoder_value = torch.split(encoder_qkv, split_size, dim=-1)
-            encoder_projections = (encoder_query, encoder_key, encoder_value)
+            encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)
 
-        return query, key, value, encoder_projections
+        return query, key, value, encoder_query, encoder_key, encoder_value
 
     def get_qkv_projections(self, attn: AttentionModuleMixin, hidden_states, encoder_hidden_states=None):
-        if hasattr(attn, "to_qkv") and attn.fused_projections:
+        if attn.fused_projections:
            return self._get_fused_projections(attn, hidden_states, encoder_hidden_states)
         return self._get_projections(attn, hidden_states, encoder_hidden_states)
 
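The fused-projection rewrite above replaces the manual `torch.split` bookkeeping with `Tensor.chunk(3, dim=-1)`, which is equivalent whenever the last dimension is divisible by three. A small self-contained check (shapes are illustrative):

```python
import torch

qkv = torch.randn(2, 16, 3 * 64)  # (batch, seq_len, 3 * inner_dim)

split_size = qkv.shape[-1] // 3
q_split, k_split, v_split = torch.split(qkv, split_size, dim=-1)
q_chunk, k_chunk, v_chunk = qkv.chunk(3, dim=-1)

assert torch.equal(q_split, q_chunk)
assert torch.equal(k_split, k_chunk)
assert torch.equal(v_split, v_chunk)
```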
@@ -87,53 +84,43 @@ class FluxAttnProcessor:
         attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        query, key, value, encoder_query, encoder_key, encoder_value = self.get_qkv_projections(
+            attn, hidden_states, encoder_hidden_states
+        )
 
-        query, key, value, encoder_projections = self.get_qkv_projections(attn, hidden_states, encoder_hidden_states)
+        query = query.unflatten(-1, (attn.heads, -1))
+        key = key.unflatten(-1, (attn.heads, -1))
+        value = value.unflatten(-1, (attn.heads, -1))
 
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
+        query = attn.norm_q(query)
+        key = attn.norm_k(key)
 
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        if attn.added_kv_proj_dim is not None:
+            encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
+            encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
+            encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
 
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
+            encoder_query = attn.norm_added_q(encoder_query)
+            encoder_key = attn.norm_added_k(encoder_key)
 
-        if encoder_projections is not None:
-            encoder_query, encoder_key, encoder_value = encoder_projections
-            encoder_query = encoder_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-            encoder_key = encoder_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-            encoder_value = encoder_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
-            if attn.norm_added_q is not None:
-                encoder_query = attn.norm_added_q(encoder_query)
-            if attn.norm_added_k is not None:
-                encoder_key = attn.norm_added_k(encoder_key)
 
-            # Concatenate for joint attention
-            query = torch.cat([encoder_query, query], dim=2)
-            key = torch.cat([encoder_key, key], dim=2)
-            value = torch.cat([encoder_value, value], dim=2)
+            query = torch.cat([encoder_query, query], dim=1)
+            key = torch.cat([encoder_key, key], dim=1)
+            value = torch.cat([encoder_value, value], dim=1)
 
         if image_rotary_emb is not None:
             query = apply_rotary_emb(query, image_rotary_emb)
             key = apply_rotary_emb(key, image_rotary_emb)
 
-        hidden_states = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = dispatch_attention_fn(
+            query, key, value, attn_mask=attention_mask, backend=self._attention_backend
+        )
+        hidden_states = hidden_states.flatten(2, 3)
         hidden_states = hidden_states.to(query.dtype)
 
         if encoder_hidden_states is not None:
-            encoder_hidden_states, hidden_states = (
-                hidden_states[:, : encoder_hidden_states.shape[1]],
-                hidden_states[:, encoder_hidden_states.shape[1] :],
+            encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
+                [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
             )
 
             hidden_states = attn.to_out[0](hidden_states)
             hidden_states = attn.to_out[1](hidden_states)
             encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
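The rewritten `__call__` keeps projections in `(batch, seq_len, heads, head_dim)` layout via `unflatten` instead of `view(...).transpose(1, 2)`, which is why the joint-attention concatenation moves from `dim=2` to `dim=1`, the output is merged back with `flatten(2, 3)`, and the text/image split uses `split_with_sizes` along `dim=1`. A small shape check of that layout change (sizes are illustrative):

```python
import torch

batch, seq_len, heads, head_dim = 2, 16, 8, 64
x = torch.randn(batch, seq_len, heads * head_dim)

old_layout = x.view(batch, -1, heads, head_dim).transpose(1, 2)  # (batch, heads, seq_len, head_dim)
new_layout = x.unflatten(-1, (heads, -1))                        # (batch, seq_len, heads, head_dim)

# Same data, different ordering: the sequence axis is dim 2 in the old layout and dim 1 in the new one.
assert torch.equal(old_layout, new_layout.transpose(1, 2))

# flatten(2, 3) undoes the unflatten, recovering (batch, seq_len, heads * head_dim).
assert new_layout.flatten(2, 3).shape == (batch, seq_len, heads * head_dim)
```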
@@ -146,6 +133,8 @@ class FluxAttnProcessor:
 class FluxIPAdapterAttnProcessor(torch.nn.Module):
     """Flux Attention processor for IP-Adapter."""
 
+    _attention_backend = None
+
     def __init__(
         self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
     ):
@@ -241,8 +230,14 @@ class FluxIPAdapterAttnProcessor(torch.nn.Module):
             query = apply_rotary_emb(query, image_rotary_emb)
             key = apply_rotary_emb(key, image_rotary_emb)
 
-        hidden_states = torch.nn.functional.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        hidden_states = dispatch_attention_fn(
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+            backend=self._attention_backend,
         )
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
@@ -273,8 +268,14 @@ class FluxIPAdapterAttnProcessor(torch.nn.Module):
                 ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                 # the output of sdp = (batch, num_heads, seq_len, head_dim)
                 # TODO: add support for attn.scale when we move to Torch 2.1
-                current_ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(
-                    ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                current_ip_hidden_states = dispatch_attention_fn(
+                    ip_query,
+                    ip_key,
+                    ip_value,
+                    attn_mask=None,
+                    dropout_p=0.0,
+                    is_causal=False,
+                    backend=self._attention_backend,
                 )
                 current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
                     batch_size, -1, attn.heads * head_dim
@@ -323,6 +324,7 @@ class FluxAttention(torch.nn.Module, AttentionModuleMixin):
         self.context_pre_only = context_pre_only
         self.pre_only = pre_only
         self.heads = out_dim // dim_head if out_dim is not None else heads
+        self.added_kv_proj_dim = added_kv_proj_dim
         self.added_proj_bias = added_proj_bias
 
         self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)

@@ -67,6 +67,9 @@ from .import_utils import (
     is_bitsandbytes_version,
     is_bs4_available,
     is_cosmos_guardrail_available,
+    is_flash_attn_3_available,
+    is_flash_attn_available,
+    is_flash_attn_version,
     is_flax_available,
     is_ftfy_available,
     is_gguf_available,
@@ -90,6 +93,8 @@ from .import_utils import (
     is_peft_version,
     is_pytorch_retinaface_available,
     is_safetensors_available,
+    is_sageattention_available,
+    is_sageattention_version,
     is_scipy_available,
     is_sentencepiece_available,
     is_tensorboard_available,
@@ -108,6 +113,7 @@ from .import_utils import (
     is_unidecode_available,
     is_wandb_available,
     is_xformers_available,
+    is_xformers_version,
     requires_backends,
 )
 from .loading_utils import get_module_from_name, get_submodule_by_name, load_image, load_video

@@ -41,6 +41,8 @@ DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
 HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules"))
 DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
 DIFFUSERS_REQUEST_TIMEOUT = 60
+DIFFUSERS_ATTN_BACKEND = os.getenv("DIFFUSERS_ATTN_BACKEND", "native")
+DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES
 
 # Below should be `True` if the current version of `peft` and `transformers` are compatible with
 # PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are

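Both constants are read once when the constants module is imported, so the environment variables must be set beforehand. A minimal sketch (which values besides `"native"` are accepted, and exactly what the checks flag enables, are not shown in this excerpt):

```python
import os

# Must be set before `import diffusers`; the constants are evaluated via os.getenv at import time.
os.environ["DIFFUSERS_ATTN_BACKEND"] = "native"  # the default shown in the diff
os.environ["DIFFUSERS_ATTN_CHECKS"] = "1"        # any ENV_VARS_TRUE_VALUES entry enables the extra checks

import diffusers  # noqa: E402
```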
@@ -220,6 +220,9 @@ _pytorch_retinaface_available, _pytorch_retinaface_version = _is_package_availab
 _better_profanity_available, _better_profanity_version = _is_package_available("better_profanity")
 _nltk_available, _nltk_version = _is_package_available("nltk")
 _cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("cosmos_guardrail")
+_sageattention_available, _sageattention_version = _is_package_available("sageattention")
+_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
+_flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3")
 
 
 def is_torch_available():
@@ -378,6 +381,18 @@ def is_hpu_available():
     return all(importlib.util.find_spec(lib) for lib in ("habana_frameworks", "habana_frameworks.torch"))
 
 
+def is_sageattention_available():
+    return _sageattention_available
+
+
+def is_flash_attn_available():
+    return _flash_attn_available
+
+
+def is_flash_attn_3_available():
+    return _flash_attn_3_available
+
+
 # docstyle-ignore
 FLAX_IMPORT_ERROR = """
 {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
@@ -804,6 +819,51 @@ def is_optimum_quanto_version(operation: str, version: str):
     return compare_versions(parse(_optimum_quanto_version), operation, version)
 
 
+def is_xformers_version(operation: str, version: str):
+    """
+    Compares the current xformers version to a given reference with an operation.
+
+    Args:
+        operation (`str`):
+            A string representation of an operator, such as `">"` or `"<="`
+        version (`str`):
+            A version string
+    """
+    if not _xformers_available:
+        return False
+    return compare_versions(parse(_xformers_version), operation, version)
+
+
+def is_sageattention_version(operation: str, version: str):
+    """
+    Compares the current sageattention version to a given reference with an operation.
+
+    Args:
+        operation (`str`):
+            A string representation of an operator, such as `">"` or `"<="`
+        version (`str`):
+            A version string
+    """
+    if not _sageattention_available:
+        return False
+    return compare_versions(parse(_sageattention_version), operation, version)
+
+
+def is_flash_attn_version(operation: str, version: str):
+    """
+    Compares the current flash-attention version to a given reference with an operation.
+
+    Args:
+        operation (`str`):
+            A string representation of an operator, such as `">"` or `"<="`
+        version (`str`):
+            A version string
+    """
+    if not _flash_attn_available:
+        return False
+    return compare_versions(parse(_flash_attn_version), operation, version)
+
+
 def get_objects_from_module(module):
     """
     Returns a dict of object names and values in a module, while skipping private/internal objects
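A typical consumer of the new helpers gates an optional fast-attention path on both availability and version; a hedged sketch (the version thresholds are illustrative, not requirements stated in this diff):

```python
from diffusers.utils import (
    is_flash_attn_available,
    is_flash_attn_version,
    is_sageattention_available,
    is_sageattention_version,
)

if is_flash_attn_available() and is_flash_attn_version(">=", "2.6.0"):
    print("flash-attn kernels could be used")
elif is_sageattention_available() and is_sageattention_version(">=", "2.0.0"):
    print("sageattention kernels could be used")
else:
    print("falling back to torch native scaled dot product attention")
```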