Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
[refactor] Flux/Chroma single file implementation + Attention Dispatcher (#11916)
* update
* update
* add coauthor

  Co-Authored-By: Dhruv Nair <dhruv.nair@gmail.com>
* improve test
* handle ip adapter params correctly
* fix chroma qkv fusion test
* fix fastercache implementation
* fix more tests
* fight more tests
* add back set_attention_backend
* update
* update
* make style
* make fix-copies
* make ip adapter processor compatible with attention dispatcher
* refactor chroma as well
* remove rmsnorm assert
* minify and deprecate npu/xla processors

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
@@ -163,6 +163,7 @@ else:
        [
            "AllegroTransformer3DModel",
            "AsymmetricAutoencoderKL",
            "AttentionBackendName",
            "AuraFlowTransformer2DModel",
            "AutoencoderDC",
            "AutoencoderKL",
@@ -238,6 +239,7 @@ else:
            "VQModel",
            "WanTransformer3DModel",
            "WanVACETransformer3DModel",
            "attention_backend",
        ]
    )
    _import_structure["modular_pipelines"].extend(
@@ -815,6 +817,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        from .models import (
            AllegroTransformer3DModel,
            AsymmetricAutoencoderKL,
            AttentionBackendName,
            AuraFlowTransformer2DModel,
            AutoencoderDC,
            AutoencoderKL,
@@ -889,6 +892,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
            VQModel,
            WanTransformer3DModel,
            WanVACETransformer3DModel,
            attention_backend,
        )
        from .modular_pipelines import (
            ComponentsManager,
@@ -18,6 +18,7 @@ from typing import Any, Callable, List, Optional, Tuple

import torch

from ..models.attention import AttentionModuleMixin
from ..models.attention_processor import Attention, MochiAttention
from ..models.modeling_outputs import Transformer2DModelOutput
from ..utils import logging
@@ -567,7 +568,7 @@ def apply_faster_cache(module: torch.nn.Module, config: FasterCacheConfig) -> No
    _apply_faster_cache_on_denoiser(module, config)

    for name, submodule in module.named_modules():
        if not isinstance(submodule, _ATTENTION_CLASSES):
        if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
            continue
        if any(re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS):
            _apply_faster_cache_on_attention_class(name, submodule, config)
@@ -18,6 +18,7 @@ from typing import Any, Callable, Optional, Tuple, Union

import torch

from ..models.attention import AttentionModuleMixin
from ..models.attention_processor import Attention, MochiAttention
from ..utils import logging
from .hooks import HookRegistry, ModelHook
@@ -227,7 +228,7 @@ def apply_pyramid_attention_broadcast(module: torch.nn.Module, config: PyramidAt
        config.spatial_attention_block_skip_range = 2

    for name, submodule in module.named_modules():
        if not isinstance(submodule, _ATTENTION_CLASSES):
        if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
            # PAB has been implemented specific to Diffusers' Attention classes. However, this does not mean that PAB
            # cannot be applied to this layer. For custom layers, users can extend this functionality and implement
            # their own PAB logic similar to `_apply_pyramid_attention_broadcast_on_attention_class`.
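Editor's note: with the broadened `isinstance` check, both cache hooks now attach to any attention module that subclasses `AttentionModuleMixin`, not only the legacy `Attention`/`MochiAttention` classes. A minimal usage sketch of the hook entry point, assuming the public `diffusers.hooks` API; the model choice and skip range are illustrative assumptions, not taken from this diff:

```python
import torch

from diffusers import FluxTransformer2DModel
from diffusers.hooks import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast

# Any denoiser whose attention layers subclass AttentionModuleMixin should work the same way.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Skip range and timestep callback are placeholder values for the sketch; tune per model/schedule.
config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    current_timestep_callback=lambda: 0,
)
apply_pyramid_attention_broadcast(transformer, config)
```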
@@ -40,8 +40,6 @@ if is_transformers_available():
    from ..models.attention_processor import (
        AttnProcessor,
        AttnProcessor2_0,
        FluxAttnProcessor2_0,
        FluxIPAdapterJointAttnProcessor2_0,
        IPAdapterAttnProcessor,
        IPAdapterAttnProcessor2_0,
        IPAdapterXFormersAttnProcessor,
@@ -867,6 +865,9 @@ class FluxIPAdapterMixin:
        >>> ...
        ```
        """
        # TODO: once the 1.0.0 deprecations are in, we can move the imports to top-level
        from ..models.transformers.transformer_flux import FluxAttnProcessor, FluxIPAdapterAttnProcessor

        # remove CLIP image encoder
        if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
            self.image_encoder = None
@@ -886,9 +887,9 @@ class FluxIPAdapterMixin:
        # restore original Transformer attention processors layers
        attn_procs = {}
        for name, value in self.transformer.attn_processors.items():
            attn_processor_class = FluxAttnProcessor2_0()
            attn_processor_class = FluxAttnProcessor()
            attn_procs[name] = (
                attn_processor_class if isinstance(value, (FluxIPAdapterJointAttnProcessor2_0)) else value.__class__()
                attn_processor_class if isinstance(value, FluxIPAdapterAttnProcessor) else value.__class__()
            )
        self.transformer.set_attn_processor(attn_procs)
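Editor's note: the hunk above is the body of `unload_ip_adapter()`, which now swaps IP-adapter processors back to plain `FluxAttnProcessor` instances. A hedged usage sketch; the checkpoint names are illustrative assumptions, not from this diff:

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

# Illustrative Flux-compatible IP-Adapter checkpoint and weight name.
pipe.load_ip_adapter("XLabs-AI/flux-ip-adapter", weight_name="ip_adapter.safetensors")
pipe.set_ip_adapter_scale(1.0)

# ... run inference with an ip_adapter_image ...

# Drop the adapter again; attention processors revert to FluxAttnProcessor.
pipe.unload_ip_adapter()
```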
@@ -86,9 +86,7 @@ class FluxTransformer2DLoadersMixin:
        return image_projection

    def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
        from ..models.attention_processor import (
            FluxIPAdapterJointAttnProcessor2_0,
        )
        from ..models.transformers.transformer_flux import FluxIPAdapterAttnProcessor

        if low_cpu_mem_usage:
            if is_accelerate_available():
@@ -120,7 +118,7 @@ class FluxTransformer2DLoadersMixin:
        else:
            cross_attention_dim = self.config.joint_attention_dim
            hidden_size = self.inner_dim
            attn_processor_class = FluxIPAdapterJointAttnProcessor2_0
            attn_processor_class = FluxIPAdapterAttnProcessor
        num_image_text_embeds = []
        for state_dict in state_dicts:
            if "proj.weight" in state_dict["image_proj"]:
@@ -26,6 +26,7 @@ _import_structure = {}

if is_torch_available():
    _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
    _import_structure["attention_dispatch"] = ["AttentionBackendName", "attention_backend"]
    _import_structure["auto_model"] = ["AutoModel"]
    _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
    _import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"]
@@ -112,6 +113,7 @@ if is_flax_available():
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    if is_torch_available():
        from .adapter import MultiAdapter, T2IAdapter
        from .attention_dispatch import AttentionBackendName, attention_backend
        from .auto_model import AutoModel
        from .autoencoders import (
            AsymmetricAutoencoderKL,
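Editor's note: the newly exported `AttentionBackendName` enum and `attention_backend` helper, together with `set_attention_backend` on attention modules, are the user-facing surface of the dispatcher. A minimal sketch, assuming the public diffusers API; backend name strings other than the `_native_npu`/`_native_xla`/`xformers` values visible in this diff are assumptions:

```python
import torch

from diffusers import FluxTransformer2DModel, attention_backend

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Persistently route every attention layer of this model through one backend
# ("native" is assumed to be the default torch SDPA backend name).
transformer.set_attention_backend("native")

# Or switch backends only for a scoped region of code
# ("flash" assumes flash-attn is installed and registered under that name).
with attention_backend("flash"):
    pass  # run the denoiser forward pass here
```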
@@ -11,23 +11,504 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn

from ..utils import deprecate, logging
from ..utils.import_utils import is_torch_npu_available, is_torch_xla_available, is_xformers_available
from ..utils.torch_utils import maybe_allow_in_graph
from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
from .attention_processor import Attention, JointAttnProcessor2_0
from .attention_processor import Attention, AttentionProcessor, JointAttnProcessor2_0
from .embeddings import SinusoidalPositionalEmbedding
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX


if is_xformers_available():
    import xformers.ops as xops
else:
    xops = None


logger = logging.get_logger(__name__)


class AttentionMixin:
    @property
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)
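Editor's note: a short sketch of how the two members above are typically used together on a model that exposes this mixin; the model and processor choice are illustrative:

```python
from diffusers import FluxTransformer2DModel
from diffusers.models.transformers.transformer_flux import FluxAttnProcessor

transformer = FluxTransformer2DModel()  # default config; illustrative only

# Inspect the current processors, keyed by the attention module's weight name.
procs = transformer.attn_processors
print(len(procs), next(iter(procs)))

# Replace every processor with a fresh FluxAttnProcessor, either globally...
transformer.set_attn_processor(FluxAttnProcessor())

# ...or per layer, by passing a dict with exactly one entry per attention layer.
transformer.set_attn_processor({name: FluxAttnProcessor() for name in procs})
```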
    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.
        """
        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        for module in self.modules():
            if isinstance(module, AttentionModuleMixin):
                module.fuse_projections()

    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        for module in self.modules():
            if isinstance(module, AttentionModuleMixin):
                module.unfuse_projections()
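Editor's note: typical round-trip usage of the two methods above on a model exposing the mixin; the model choice is illustrative:

```python
import torch
from diffusers import FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Merge to_q/to_k/to_v (and add_*_proj) into single linear layers for fewer GEMM launches.
transformer.fuse_qkv_projections()

# ... run inference ...

# Restore the original unfused layout, e.g. before loading LoRA weights keyed on to_q/to_k/to_v.
transformer.unfuse_qkv_projections()
```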
class AttentionModuleMixin:
    _default_processor_cls = None
    _available_processors = []
    fused_projections = False

    def set_processor(self, processor: AttentionProcessor) -> None:
        """
        Set the attention processor to use.

        Args:
            processor (`AttnProcessor`):
                The attention processor to use.
        """
        # if current processor is in `self._modules` and if passed `processor` is not, we need to
        # pop `processor` from `self._modules`
        if (
            hasattr(self, "processor")
            and isinstance(self.processor, torch.nn.Module)
            and not isinstance(processor, torch.nn.Module)
        ):
            logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
            self._modules.pop("processor")

        self.processor = processor

    def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
        """
        Get the attention processor in use.

        Args:
            return_deprecated_lora (`bool`, *optional*, defaults to `False`):
                Set to `True` to return the deprecated LoRA attention processor.

        Returns:
            "AttentionProcessor": The attention processor in use.
        """
        if not return_deprecated_lora:
            return self.processor

    def set_attention_backend(self, backend: str):
        from .attention_dispatch import AttentionBackendName

        available_backends = {x.value for x in AttentionBackendName.__members__.values()}
        if backend not in available_backends:
            raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))

        backend = AttentionBackendName(backend.lower())
        self.processor._attention_backend = backend

    def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
        """
        Set whether to use NPU flash attention from `torch_npu` or not.

        Args:
            use_npu_flash_attention (`bool`): Whether to use NPU flash attention or not.
        """

        if use_npu_flash_attention:
            if not is_torch_npu_available():
                raise ImportError("torch_npu is not available")

            self.set_attention_backend("_native_npu")

    def set_use_xla_flash_attention(
        self,
        use_xla_flash_attention: bool,
        partition_spec: Optional[Tuple[Optional[str], ...]] = None,
        is_flux=False,
    ) -> None:
        """
        Set whether to use XLA flash attention from `torch_xla` or not.

        Args:
            use_xla_flash_attention (`bool`):
                Whether to use pallas flash attention kernel from `torch_xla` or not.
            partition_spec (`Tuple[]`, *optional*):
                Specify the partition specification if using SPMD. Otherwise None.
            is_flux (`bool`, *optional*, defaults to `False`):
                Whether the model is a Flux model.
        """
        if use_xla_flash_attention:
            if not is_torch_xla_available():
                raise ImportError("torch_xla is not available")

            self.set_attention_backend("_native_xla")

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
    ) -> None:
        """
        Set whether to use memory efficient attention from `xformers` or not.

        Args:
            use_memory_efficient_attention_xformers (`bool`):
                Whether to use memory efficient attention from `xformers` or not.
            attention_op (`Callable`, *optional*):
                The attention operation to use. Defaults to `None` which uses the default attention operation from
                `xformers`.
        """
        if use_memory_efficient_attention_xformers:
            if not is_xformers_available():
                raise ModuleNotFoundError(
                    "Refer to https://github.com/facebookresearch/xformers for more information on how to install xformers",
                    name="xformers",
                )
            elif not torch.cuda.is_available():
                raise ValueError(
                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
                    " only available for GPU "
                )
            else:
                try:
                    # Make sure we can run the memory efficient attention
                    if is_xformers_available():
                        dtype = None
                        if attention_op is not None:
                            op_fw, op_bw = attention_op
                            dtype, *_ = op_fw.SUPPORTED_DTYPES
                        q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
                        _ = xops.memory_efficient_attention(q, q, q)
                except Exception as e:
                    raise e

            self.set_attention_backend("xformers")

    @torch.no_grad()
    def fuse_projections(self):
        """
        Fuse the query, key, and value projections into a single projection for efficiency.
        """
        # Skip if already fused
        if getattr(self, "fused_projections", False):
            return

        device = self.to_q.weight.data.device
        dtype = self.to_q.weight.data.dtype

        if hasattr(self, "is_cross_attention") and self.is_cross_attention:
            # Fuse cross-attention key-value projections
            concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
            in_features = concatenated_weights.shape[1]
            out_features = concatenated_weights.shape[0]

            self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
            self.to_kv.weight.copy_(concatenated_weights)
            if hasattr(self, "use_bias") and self.use_bias:
                concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
                self.to_kv.bias.copy_(concatenated_bias)
        else:
            # Fuse self-attention projections
            concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
            in_features = concatenated_weights.shape[1]
            out_features = concatenated_weights.shape[0]

            self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
            self.to_qkv.weight.copy_(concatenated_weights)
            if hasattr(self, "use_bias") and self.use_bias:
                concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
                self.to_qkv.bias.copy_(concatenated_bias)

        # Handle added projections for models like SD3, Flux, etc.
        if (
            getattr(self, "add_q_proj", None) is not None
            and getattr(self, "add_k_proj", None) is not None
            and getattr(self, "add_v_proj", None) is not None
        ):
            concatenated_weights = torch.cat(
                [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
            )
            in_features = concatenated_weights.shape[1]
            out_features = concatenated_weights.shape[0]

            self.to_added_qkv = nn.Linear(
                in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
            )
            self.to_added_qkv.weight.copy_(concatenated_weights)
            if self.added_proj_bias:
                concatenated_bias = torch.cat(
                    [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
                )
                self.to_added_qkv.bias.copy_(concatenated_bias)

        self.fused_projections = True
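Editor's note: the fusion above is a pure re-packing of weights; stacking `to_q`, `to_k`, `to_v` along the output dimension and splitting the fused output reproduces the separate projections. A small self-contained check of that identity (dimensions are arbitrary):

```python
import torch
import torch.nn as nn

dim, seq = 64, 10
to_q, to_k, to_v = (nn.Linear(dim, dim, bias=False) for _ in range(3))

# Same packing order as fuse_projections: rows of W_q, then W_k, then W_v.
to_qkv = nn.Linear(dim, 3 * dim, bias=False)
with torch.no_grad():
    to_qkv.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight]))

x = torch.randn(2, seq, dim)
q, k, v = to_qkv(x).split(dim, dim=-1)

assert torch.allclose(q, to_q(x), atol=1e-6)
assert torch.allclose(k, to_k(x), atol=1e-6)
assert torch.allclose(v, to_v(x), atol=1e-6)
```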
    @torch.no_grad()
    def unfuse_projections(self):
        """
        Unfuse the query, key, and value projections back to separate projections.
        """
        # Skip if not fused
        if not getattr(self, "fused_projections", False):
            return

        # Remove fused projection layers
        if hasattr(self, "to_qkv"):
            delattr(self, "to_qkv")

        if hasattr(self, "to_kv"):
            delattr(self, "to_kv")

        if hasattr(self, "to_added_qkv"):
            delattr(self, "to_added_qkv")

        self.fused_projections = False

    def set_attention_slice(self, slice_size: int) -> None:
        """
        Set the slice size for attention computation.

        Args:
            slice_size (`int`):
                The slice size for attention computation.
        """
        if hasattr(self, "sliceable_head_dim") and slice_size is not None and slice_size > self.sliceable_head_dim:
            raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")

        processor = None

        # Try to get a compatible processor for sliced attention
        if slice_size is not None:
            processor = self._get_compatible_processor("sliced")

        # If no processor was found or slice_size is None, use default processor
        if processor is None:
            processor = self.default_processor_cls()

        self.set_processor(processor)

    def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`.

        Args:
            tensor (`torch.Tensor`): The tensor to reshape.

        Returns:
            `torch.Tensor`: The reshaped tensor.
        """
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

    def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
        """
        Reshape the tensor for multi-head attention processing.

        Args:
            tensor (`torch.Tensor`): The tensor to reshape.
            out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor.

        Returns:
            `torch.Tensor`: The reshaped tensor.
        """
        head_size = self.heads
        if tensor.ndim == 3:
            batch_size, seq_len, dim = tensor.shape
            extra_dim = 1
        else:
            batch_size, extra_dim, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3)

        if out_dim == 3:
            tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)

        return tensor

    def get_attention_scores(
        self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Compute the attention scores.

        Args:
            query (`torch.Tensor`): The query tensor.
            key (`torch.Tensor`): The key tensor.
            attention_mask (`torch.Tensor`, *optional*): The attention mask to use.

        Returns:
            `torch.Tensor`: The attention probabilities/scores.
        """
        dtype = query.dtype
        if self.upcast_attention:
            query = query.float()
            key = key.float()

        if attention_mask is None:
            baddbmm_input = torch.empty(
                query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
            )
            beta = 0
        else:
            baddbmm_input = attention_mask
            beta = 1

        attention_scores = torch.baddbmm(
            baddbmm_input,
            query,
            key.transpose(-1, -2),
            beta=beta,
            alpha=self.scale,
        )
        del baddbmm_input

        if self.upcast_softmax:
            attention_scores = attention_scores.float()

        attention_probs = attention_scores.softmax(dim=-1)
        del attention_scores

        attention_probs = attention_probs.to(dtype)

        return attention_probs
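Editor's note: `get_attention_scores` computes the standard masked softmax attention weights via `baddbmm`. A small equivalence check of what that call does, written out directly (shapes are arbitrary):

```python
import torch

b, q_len, k_len, d = 2, 4, 6, 8
scale = d ** -0.5
query, key = torch.randn(b, q_len, d), torch.randn(b, k_len, d)
mask = torch.zeros(b, q_len, k_len)  # additive mask; 0 keeps a position, large negative drops it

# baddbmm form used above: softmax(beta * mask + alpha * Q @ K^T)
probs_baddbmm = torch.baddbmm(mask, query, key.transpose(-1, -2), beta=1, alpha=scale).softmax(dim=-1)

# The same computation written directly.
probs_direct = (scale * query @ key.transpose(-1, -2) + mask).softmax(dim=-1)

assert torch.allclose(probs_baddbmm, probs_direct, atol=1e-6)
```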
    def prepare_attention_mask(
        self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
    ) -> torch.Tensor:
        """
        Prepare the attention mask for the attention computation.

        Args:
            attention_mask (`torch.Tensor`): The attention mask to prepare.
            target_length (`int`): The target length of the attention mask.
            batch_size (`int`): The batch size for repeating the attention mask.
            out_dim (`int`, *optional*, defaults to `3`): Output dimension.

        Returns:
            `torch.Tensor`: The prepared attention mask.
        """
        head_size = self.heads
        if attention_mask is None:
            return attention_mask

        current_length: int = attention_mask.shape[-1]
        if current_length != target_length:
            if attention_mask.device.type == "mps":
                # HACK: MPS: Does not support padding by greater than dimension of input tensor.
                # Instead, we can manually construct the padding tensor.
                padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
                padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
                attention_mask = torch.cat([attention_mask, padding], dim=2)
            else:
                # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
                # we want to instead pad by (0, remaining_length), where remaining_length is:
                # remaining_length: int = target_length - current_length
                # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)

        if out_dim == 3:
            if attention_mask.shape[0] < batch_size * head_size:
                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
        elif out_dim == 4:
            attention_mask = attention_mask.unsqueeze(1)
            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)

        return attention_mask

    def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Normalize the encoder hidden states.

        Args:
            encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.

        Returns:
            `torch.Tensor`: The normalized encoder hidden states.
        """
        assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
        if isinstance(self.norm_cross, nn.LayerNorm):
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
        elif isinstance(self.norm_cross, nn.GroupNorm):
            # Group norm norms along the channels dimension and expects
            # input to be in the shape of (N, C, *). In this case, we want
            # to norm along the hidden dimension, so we need to move
            # (batch_size, sequence_length, hidden_size) ->
            # (batch_size, hidden_size, sequence_length)
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
        else:
            assert False

        return encoder_hidden_states


def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
    # "feed_forward_chunk_size" can be used to save memory
    if hidden_states.shape[chunk_dim] % chunk_size != 0:
src/diffusers/models/attention_dispatch.py (new file, 1155 lines added): file diff suppressed because it is too large.
@@ -2272,558 +2272,6 @@ class FusedAuraFlowAttnProcessor2_0:
        return hidden_states


class FluxAttnProcessor2_0:
    """Attention processor used typically in processing the SD3-like self-attention projections."""

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
        if encoder_hidden_states is not None:
            # `context` projections.
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            # attention
            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

        if image_rotary_emb is not None:
            from .embeddings import apply_rotary_emb

            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            encoder_hidden_states, hidden_states = (
                hidden_states[:, : encoder_hidden_states.shape[1]],
                hidden_states[:, encoder_hidden_states.shape[1] :],
            )

            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)

            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            return hidden_states, encoder_hidden_states
        else:
            return hidden_states
class FluxAttnProcessor2_0_NPU:
    """Attention processor used typically in processing the SD3-like self-attention projections."""

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0 and install torch NPU"
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
        if encoder_hidden_states is not None:
            # `context` projections.
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            # attention
            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

        if image_rotary_emb is not None:
            from .embeddings import apply_rotary_emb

            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        if query.dtype in (torch.float16, torch.bfloat16):
            hidden_states = torch_npu.npu_fusion_attention(
                query,
                key,
                value,
                attn.heads,
                input_layout="BNSD",
                pse=None,
                scale=1.0 / math.sqrt(query.shape[-1]),
                pre_tockens=65536,
                next_tockens=65536,
                keep_prob=1.0,
                sync=False,
                inner_precise=0,
            )[0]
        else:
            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            encoder_hidden_states, hidden_states = (
                hidden_states[:, : encoder_hidden_states.shape[1]],
                hidden_states[:, encoder_hidden_states.shape[1] :],
            )

            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            return hidden_states, encoder_hidden_states
        else:
            return hidden_states
class FusedFluxAttnProcessor2_0:
    """Attention processor used typically in processing the SD3-like self-attention projections."""

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

        # `sample` projections.
        qkv = attn.to_qkv(hidden_states)
        split_size = qkv.shape[-1] // 3
        query, key, value = torch.split(qkv, split_size, dim=-1)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
        # `context` projections.
        if encoder_hidden_states is not None:
            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
            split_size = encoder_qkv.shape[-1] // 3
            (
                encoder_hidden_states_query_proj,
                encoder_hidden_states_key_proj,
                encoder_hidden_states_value_proj,
            ) = torch.split(encoder_qkv, split_size, dim=-1)

            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            # attention
            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

        if image_rotary_emb is not None:
            from .embeddings import apply_rotary_emb

            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            encoder_hidden_states, hidden_states = (
                hidden_states[:, : encoder_hidden_states.shape[1]],
                hidden_states[:, encoder_hidden_states.shape[1] :],
            )

            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            return hidden_states, encoder_hidden_states
        else:
            return hidden_states
class FusedFluxAttnProcessor2_0_NPU:
    """Attention processor used typically in processing the SD3-like self-attention projections."""

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0, and install torch NPU"
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

        # `sample` projections.
        qkv = attn.to_qkv(hidden_states)
        split_size = qkv.shape[-1] // 3
        query, key, value = torch.split(qkv, split_size, dim=-1)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
        # `context` projections.
        if encoder_hidden_states is not None:
            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
            split_size = encoder_qkv.shape[-1] // 3
            (
                encoder_hidden_states_query_proj,
                encoder_hidden_states_key_proj,
                encoder_hidden_states_value_proj,
            ) = torch.split(encoder_qkv, split_size, dim=-1)

            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            # attention
            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

        if image_rotary_emb is not None:
            from .embeddings import apply_rotary_emb

            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        if query.dtype in (torch.float16, torch.bfloat16):
            hidden_states = torch_npu.npu_fusion_attention(
                query,
                key,
                value,
                attn.heads,
                input_layout="BNSD",
                pse=None,
                scale=1.0 / math.sqrt(query.shape[-1]),
                pre_tockens=65536,
                next_tockens=65536,
                keep_prob=1.0,
                sync=False,
                inner_precise=0,
            )[0]
        else:
            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            encoder_hidden_states, hidden_states = (
                hidden_states[:, : encoder_hidden_states.shape[1]],
                hidden_states[:, encoder_hidden_states.shape[1] :],
            )

            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            return hidden_states, encoder_hidden_states
        else:
            return hidden_states
class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):
    """Flux Attention processor for IP-Adapter."""

    def __init__(
        self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
    ):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim

        if not isinstance(num_tokens, (tuple, list)):
            num_tokens = [num_tokens]

        if not isinstance(scale, list):
            scale = [scale] * len(num_tokens)
        if len(scale) != len(num_tokens):
            raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
        self.scale = scale

        self.to_k_ip = nn.ModuleList(
            [
                nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
                for _ in range(len(num_tokens))
            ]
        )
        self.to_v_ip = nn.ModuleList(
            [
                nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
                for _ in range(len(num_tokens))
            ]
        )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        ip_hidden_states: Optional[List[torch.Tensor]] = None,
        ip_adapter_masks: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

        # `sample` projections.
        hidden_states_query_proj = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            hidden_states_query_proj = attn.norm_q(hidden_states_query_proj)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
        if encoder_hidden_states is not None:
            # `context` projections.
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            # attention
            query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

        if image_rotary_emb is not None:
            from .embeddings import apply_rotary_emb

            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            encoder_hidden_states, hidden_states = (
                hidden_states[:, : encoder_hidden_states.shape[1]],
                hidden_states[:, encoder_hidden_states.shape[1] :],
            )

            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            # IP-adapter
            ip_query = hidden_states_query_proj
            ip_attn_output = torch.zeros_like(hidden_states)

            for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
                ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
            ):
                ip_key = to_k_ip(current_ip_hidden_states)
                ip_value = to_v_ip(current_ip_hidden_states)

                ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                # the output of sdp = (batch, num_heads, seq_len, head_dim)
                # TODO: add support for attn.scale when we move to Torch 2.1
                current_ip_hidden_states = F.scaled_dot_product_attention(
                    ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
                )
                current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
                    batch_size, -1, attn.heads * head_dim
                )
                current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
                ip_attn_output += scale * current_ip_hidden_states

            return hidden_states, encoder_hidden_states, ip_attn_output
        else:
            return hidden_states


class CogVideoXAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
@@ -3453,106 +2901,6 @@ class XLAFlashAttnProcessor2_0:
        return hidden_states
class XLAFluxFlashAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
    """

    def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )
        if is_torch_xla_version("<", "2.3"):
            raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
        if is_spmd() and is_torch_xla_version("<", "2.4"):
            raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
        self.partition_spec = partition_spec

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
        if encoder_hidden_states is not None:
            # `context` projections.
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            # attention
            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

        if image_rotary_emb is not None:
            from .embeddings import apply_rotary_emb

            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        query /= math.sqrt(head_dim)
        hidden_states = flash_attention(query, key, value, causal=False)

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            encoder_hidden_states, hidden_states = (
                hidden_states[:, : encoder_hidden_states.shape[1]],
                hidden_states[:, encoder_hidden_states.shape[1] :],
            )

            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)

            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            return hidden_states, encoder_hidden_states
        else:
            return hidden_states


class MochiVaeAttnProcessor2_0:
    r"""
    Attention processor used in Mochi VAE.
@@ -5992,17 +5340,6 @@ class LoRAAttnAddedKVProcessor:
    pass


class FluxSingleAttnProcessor2_0(FluxAttnProcessor2_0):
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def __init__(self):
        deprecation_message = "`FluxSingleAttnProcessor2_0` is deprecated and will be removed in a future version. Please use `FluxAttnProcessor2_0` instead."
        deprecate("FluxSingleAttnProcessor2_0", "0.32.0", deprecation_message)
        super().__init__()


class SanaLinearAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product linear attention.
@@ -6167,6 +5504,111 @@ class PAGIdentitySanaLinearAttnProcessor2_0:
        return hidden_states
class FluxAttnProcessor2_0:
|
||||
def __new__(cls, *args, **kwargs):
|
||||
deprecation_message = "`FluxAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxAttnProcessor`"
|
||||
deprecate("FluxAttnProcessor2_0", "1.0.0", deprecation_message)
|
||||
|
||||
from .transformers.transformer_flux import FluxAttnProcessor
|
||||
|
||||
return FluxAttnProcessor(*args, **kwargs)
|
||||
|
||||
|
||||
class FluxSingleAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
||||
"""
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
deprecation_message = "`FluxSingleAttnProcessor` is deprecated and will be removed in a future version. Please use `FluxAttnProcessorSDPA` instead."
|
||||
deprecate("FluxSingleAttnProcessor2_0", "1.0.0", deprecation_message)
|
||||

        from .transformers.transformer_flux import FluxAttnProcessor

        return FluxAttnProcessor(*args, **kwargs)


class FusedFluxAttnProcessor2_0:
    def __new__(cls, *args, **kwargs):
        deprecation_message = "`FusedFluxAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxAttnProcessor`"
        deprecate("FusedFluxAttnProcessor2_0", "1.0.0", deprecation_message)

        from .transformers.transformer_flux import FluxAttnProcessor

        return FluxAttnProcessor(*args, **kwargs)


class FluxIPAdapterJointAttnProcessor2_0:
    def __new__(cls, *args, **kwargs):
        deprecation_message = "`FluxIPAdapterJointAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxIPAdapterAttnProcessor`"
        deprecate("FluxIPAdapterJointAttnProcessor2_0", "1.0.0", deprecation_message)

        from .transformers.transformer_flux import FluxIPAdapterAttnProcessor

        return FluxIPAdapterAttnProcessor(*args, **kwargs)


class FluxAttnProcessor2_0_NPU:
    def __new__(cls, *args, **kwargs):
        deprecation_message = (
            "FluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An "
            "alternative solution to use NPU Flash Attention will be provided in the future."
        )
        deprecate("FluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False)

        from .transformers.transformer_flux import FluxAttnProcessor

        processor = FluxAttnProcessor()
        processor._attention_backend = "_native_npu"
        return processor


class FusedFluxAttnProcessor2_0_NPU:
    def __new__(self):
        deprecation_message = (
            "FusedFluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An "
            "alternative solution to use NPU Flash Attention will be provided in the future."
        )
        deprecate("FusedFluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False)

        from .transformers.transformer_flux import FluxAttnProcessor

        processor = FluxAttnProcessor()
        processor._attention_backend = "_fused_npu"
        return processor


class XLAFluxFlashAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
    """

    def __new__(cls, *args, **kwargs):
        deprecation_message = (
            "XLAFluxFlashAttnProcessor2_0 is deprecated and will be removed in diffusers 1.0.0. An "
            "alternative solution to using XLA Flash Attention will be provided in the future."
        )
        deprecate("XLAFluxFlashAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False)

        if is_torch_xla_version("<", "2.3"):
            raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
        if is_spmd() and is_torch_xla_version("<", "2.4"):
            raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")

        from .transformers.transformer_flux import FluxAttnProcessor

        if len(args) > 0 or kwargs.get("partition_spec", None) is not None:
            deprecation_message = (
                "partition_spec was not used in the processor implementation when it was added. Passing it "
                "is a no-op and support for it will be removed."
            )
            deprecate("partition_spec", "1.0.0", deprecation_message)

        processor = FluxAttnProcessor(*args, **kwargs)
        processor._attention_backend = "_native_xla"
        return processor
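# A minimal sketch of the public call path that replaces the NPU/XLA-specific processors above (assuming a
# loaded Flux transformer module named `transformer`; the backend strings are the ones used in the shims):
#
#     transformer.set_attn_processor(FluxAttnProcessor())
#     transformer.set_attention_backend("_native_npu")  # or "_fused_npu" / "_native_xla"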


ADDED_KV_ATTENTION_PROCESSORS = (
    AttnAddedKVProcessor,
    SlicedAttnAddedKVProcessor,

@@ -1181,6 +1181,7 @@ def apply_rotary_emb(
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
    sequence_dim: int = 2,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
@@ -1198,8 +1199,15 @@ def apply_rotary_emb(
    """
    if use_real:
        cos, sin = freqs_cis  # [S, D]
        cos = cos[None, None]
        sin = sin[None, None]
        if sequence_dim == 2:
            cos = cos[None, None, :, :]
            sin = sin[None, None, :, :]
        elif sequence_dim == 1:
            cos = cos[None, :, None, :]
            sin = sin[None, :, None, :]
        else:
            raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")

        cos, sin = cos.to(x.device), sin.to(x.device)

        if use_real_unbind_dim == -1:
@@ -1243,37 +1251,6 @@ def apply_rotary_emb_allegro(x: torch.Tensor, freqs_cis, positions):
    return x
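# A rough sketch of the new `sequence_dim` argument to `apply_rotary_emb` (the shapes below are illustrative
# assumptions): `sequence_dim=2` expects `x` laid out as (batch, heads, seq, head_dim), while `sequence_dim=1`
# expects (batch, seq, heads, head_dim), which is the layout used by the new Flux attention processors.
#
#     x = torch.randn(1, 16, 8, 64)                       # (batch, seq, heads, head_dim)
#     cos, sin = torch.ones(16, 64), torch.zeros(16, 64)  # per-position frequencies, [S, D]
#     out = apply_rotary_emb(x, (cos, sin), sequence_dim=1)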


class FluxPosEmbed(nn.Module):
    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
    def __init__(self, theta: int, axes_dim: List[int]):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        n_axes = ids.shape[-1]
        cos_out = []
        sin_out = []
        pos = ids.float()
        is_mps = ids.device.type == "mps"
        is_npu = ids.device.type == "npu"
        freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
        for i in range(n_axes):
            cos, sin = get_1d_rotary_pos_embed(
                self.axes_dim[i],
                pos[:, i],
                theta=self.theta,
                repeat_interleave_real=True,
                use_real=True,
                freqs_dtype=freqs_dtype,
            )
            cos_out.append(cos)
            sin_out.append(sin)
        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
        return freqs_cos, freqs_sin


class TimestepEmbedding(nn.Module):
    def __init__(
        self,
@@ -2624,3 +2601,13 @@ class MultiIPAdapterImageProjection(nn.Module):
            projected_image_embeds.append(image_embed)

        return projected_image_embeds


class FluxPosEmbed(nn.Module):
    def __new__(cls, *args, **kwargs):
        deprecation_message = "Importing and using `FluxPosEmbed` from `diffusers.models.embeddings` is deprecated. Please import it from `diffusers.models.transformers.transformer_flux`."
        deprecate("FluxPosEmbed", "1.0.0", deprecation_message)

        from .transformers.transformer_flux import FluxPosEmbed

        return FluxPosEmbed(*args, **kwargs)

@@ -610,6 +610,56 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
            offload_to_disk_path=offload_to_disk_path,
        )

    def set_attention_backend(self, backend: str) -> None:
        """
        Set the attention backend for the model.

        Args:
            backend (`str`):
                The name of the backend to set. Must be one of the available backends defined in
                `AttentionBackendName`. Available backends can be found in
                `diffusers.attention_dispatch.AttentionBackendName`. Defaults to torch native scaled dot product
                attention as backend.
        """
        from .attention import AttentionModuleMixin
        from .attention_dispatch import AttentionBackendName

        # TODO: the following will not be required when everything is refactored to AttentionModuleMixin
        from .attention_processor import Attention, MochiAttention

        backend = backend.lower()
        available_backends = {x.value for x in AttentionBackendName.__members__.values()}
        if backend not in available_backends:
            raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))

        backend = AttentionBackendName(backend)
        attention_classes = (Attention, MochiAttention, AttentionModuleMixin)

        for module in self.modules():
            if not isinstance(module, attention_classes):
                continue
            processor = module.processor
            if processor is None or not hasattr(processor, "_attention_backend"):
                continue
            processor._attention_backend = backend

    def reset_attention_backend(self) -> None:
        """
        Resets the attention backend for the model. Following calls to `forward` will use the environment default or
        the torch native scaled dot product attention.
        """
        from .attention import AttentionModuleMixin
        from .attention_processor import Attention, MochiAttention

        attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
        for module in self.modules():
            if not isinstance(module, attention_classes):
                continue
            processor = module.processor
            if processor is None or not hasattr(processor, "_attention_backend"):
                continue
            processor._attention_backend = None
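    # A minimal usage sketch for the two methods above (the checkpoint id and backend name are illustrative
    # assumptions; valid names are the string values of `AttentionBackendName`):
    #
    #     transformer = FluxTransformer2DModel.from_pretrained(
    #         "black-forest-labs/FLUX.1-dev", subfolder="transformer"
    #     )
    #     transformer.set_attention_backend("flash")   # e.g. use flash-attn if it is installed
    #     ...                                          # run inference
    #     transformer.reset_attention_backend()        # back to the default native SDPA path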

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],

@@ -24,19 +24,13 @@ from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, Pe
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.import_utils import is_torch_npu_available
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import (
    Attention,
    AttentionProcessor,
    FluxAttnProcessor2_0,
    FluxAttnProcessor2_0_NPU,
    FusedFluxAttnProcessor2_0,
)
from ..attention import AttentionMixin, FeedForward
from ..cache_utils import CacheMixin
from ..embeddings import FluxPosEmbed, PixArtAlphaTextProjection, Timesteps, get_timestep_embedding
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import CombinedTimestepLabelEmbeddings, FP32LayerNorm, RMSNorm
from .transformer_flux import FluxAttention, FluxAttnProcessor


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -223,6 +217,8 @@ class ChromaSingleTransformerBlock(nn.Module):
        self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)

        if is_torch_npu_available():
            from ..attention_processor import FluxAttnProcessor2_0_NPU

            deprecation_message = (
                "Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
                "should be set explicitly using the `set_attn_processor` method."
@@ -230,17 +226,15 @@ class ChromaSingleTransformerBlock(nn.Module):
            deprecate("npu_processor", "0.34.0", deprecation_message)
            processor = FluxAttnProcessor2_0_NPU()
        else:
            processor = FluxAttnProcessor2_0()
            processor = FluxAttnProcessor()

        self.attn = Attention(
        self.attn = FluxAttention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            bias=True,
            processor=processor,
            qk_norm="rms_norm",
            eps=1e-6,
            pre_only=True,
        )
@@ -292,17 +286,15 @@ class ChromaTransformerBlock(nn.Module):
        self.norm1 = ChromaAdaLayerNormZeroPruned(dim)
        self.norm1_context = ChromaAdaLayerNormZeroPruned(dim)

        self.attn = Attention(
        self.attn = FluxAttention(
            query_dim=dim,
            cross_attention_dim=None,
            added_kv_proj_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            context_pre_only=False,
            bias=True,
            processor=FluxAttnProcessor2_0(),
            qk_norm=qk_norm,
            processor=FluxAttnProcessor(),
            eps=eps,
        )

@@ -376,7 +368,13 @@ class ChromaTransformerBlock(nn.Module):


class ChromaTransformer2DModel(
    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
    ModelMixin,
    ConfigMixin,
    PeftAdapterMixin,
    FromOriginalModelMixin,
    FluxTransformer2DLoadersMixin,
    CacheMixin,
    AttentionMixin,
):
    """
    The Transformer model introduced in Flux, modified for Chroma.
@@ -475,106 +473,6 @@ class ChromaTransformer2DModel(
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
|
||||
def fuse_qkv_projections(self):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
||||
are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
self.original_attn_processors = None
|
||||
|
||||
for _, attn_processor in self.attn_processors.items():
|
||||
if "Added" in str(attn_processor.__class__.__name__):
|
||||
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
||||
|
||||
self.original_attn_processors = self.attn_processors
|
||||
|
||||
for module in self.modules():
|
||||
if isinstance(module, Attention):
|
||||
module.fuse_projections(fuse=True)
|
||||
|
||||
self.set_attn_processor(FusedFluxAttnProcessor2_0())
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
||||
def unfuse_qkv_projections(self):
|
||||
"""Disables the fused QKV projection if enabled.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
"""
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
|
||||
@@ -12,28 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Any, Dict, Optional, Tuple, Union
import inspect
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.import_utils import is_torch_npu_available
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import (
    Attention,
    AttentionProcessor,
    FluxAttnProcessor2_0,
    FluxAttnProcessor2_0_NPU,
    FusedFluxAttnProcessor2_0,
)
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
from ..embeddings import (
    CombinedTimestepGuidanceTextProjEmbeddings,
    CombinedTimestepTextProjEmbeddings,
    apply_rotary_emb,
    get_1d_rotary_pos_embed,
)
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
@@ -42,6 +42,307 @@ from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNo
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
    query = attn.to_q(hidden_states)
    key = attn.to_k(hidden_states)
    value = attn.to_v(hidden_states)

    encoder_query = encoder_key = encoder_value = None
    if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
        encoder_query = attn.add_q_proj(encoder_hidden_states)
        encoder_key = attn.add_k_proj(encoder_hidden_states)
        encoder_value = attn.add_v_proj(encoder_hidden_states)

    return query, key, value, encoder_query, encoder_key, encoder_value


def _get_fused_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
    query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)

    encoder_query = encoder_key = encoder_value = (None,)
    if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
        encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)

    return query, key, value, encoder_query, encoder_key, encoder_value


def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
    if attn.fused_projections:
        return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
    return _get_projections(attn, hidden_states, encoder_hidden_states)


class FluxAttnProcessor:
    _attention_backend = None

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")

    def __call__(
        self,
        attn: "FluxAttention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
            attn, hidden_states, encoder_hidden_states
        )

        query = query.unflatten(-1, (attn.heads, -1))
        key = key.unflatten(-1, (attn.heads, -1))
        value = value.unflatten(-1, (attn.heads, -1))

        query = attn.norm_q(query)
        key = attn.norm_k(key)

        if attn.added_kv_proj_dim is not None:
            encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
            encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
            encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))

            encoder_query = attn.norm_added_q(encoder_query)
            encoder_key = attn.norm_added_k(encoder_key)

            query = torch.cat([encoder_query, query], dim=1)
            key = torch.cat([encoder_key, key], dim=1)
            value = torch.cat([encoder_value, value], dim=1)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)

        hidden_states = dispatch_attention_fn(
            query, key, value, attn_mask=attention_mask, backend=self._attention_backend
        )
        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
                [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
            )
            hidden_states = attn.to_out[0](hidden_states)
            hidden_states = attn.to_out[1](hidden_states)
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            return hidden_states, encoder_hidden_states
        else:
            return hidden_states
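    # The processor above hands (batch, seq_len, heads, head_dim) tensors to `dispatch_attention_fn`, which
    # routes them to the backend selected via `_attention_backend` (or the environment default). A tiny
    # standalone sketch with made-up shapes:
    #
    #     q = k = v = torch.randn(1, 16, 8, 64)
    #     out = dispatch_attention_fn(q, k, v, attn_mask=None, backend=None)  # None -> default backend
    #     out = out.flatten(2, 3)                                             # (1, 16, 512)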


class FluxIPAdapterAttnProcessor(torch.nn.Module):
    """Flux Attention processor for IP-Adapter."""

    _attention_backend = None

    def __init__(
        self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
    ):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim

        if not isinstance(num_tokens, (tuple, list)):
            num_tokens = [num_tokens]

        if not isinstance(scale, list):
            scale = [scale] * len(num_tokens)
        if len(scale) != len(num_tokens):
            raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
        self.scale = scale

        self.to_k_ip = nn.ModuleList(
            [
                nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
                for _ in range(len(num_tokens))
            ]
        )
        self.to_v_ip = nn.ModuleList(
            [
                nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
                for _ in range(len(num_tokens))
            ]
        )

    def __call__(
        self,
        attn: "FluxAttention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        ip_hidden_states: Optional[List[torch.Tensor]] = None,
        ip_adapter_masks: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        batch_size = hidden_states.shape[0]

        query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
            attn, hidden_states, encoder_hidden_states
        )

        query = query.unflatten(-1, (attn.heads, -1))
        key = key.unflatten(-1, (attn.heads, -1))
        value = value.unflatten(-1, (attn.heads, -1))

        query = attn.norm_q(query)
        key = attn.norm_k(key)
        ip_query = query

        if encoder_hidden_states is not None:
            encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
            encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
            encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))

            encoder_query = attn.norm_added_q(encoder_query)
            encoder_key = attn.norm_added_k(encoder_key)

            query = torch.cat([encoder_query, query], dim=1)
            key = torch.cat([encoder_key, key], dim=1)
            value = torch.cat([encoder_value, value], dim=1)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)

        hidden_states = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
            backend=self._attention_backend,
        )
        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
                [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
            )
            hidden_states = attn.to_out[0](hidden_states)
            hidden_states = attn.to_out[1](hidden_states)
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            # IP-adapter
            ip_attn_output = torch.zeros_like(hidden_states)

            for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
                ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
            ):
                ip_key = to_k_ip(current_ip_hidden_states)
                ip_value = to_v_ip(current_ip_hidden_states)

                ip_key = ip_key.view(batch_size, -1, attn.heads, attn.head_dim)
                ip_value = ip_value.view(batch_size, -1, attn.heads, attn.head_dim)

                current_ip_hidden_states = dispatch_attention_fn(
                    ip_query,
                    ip_key,
                    ip_value,
                    attn_mask=None,
                    dropout_p=0.0,
                    is_causal=False,
                    backend=self._attention_backend,
                )
                current_ip_hidden_states = current_ip_hidden_states.reshape(batch_size, -1, attn.heads * attn.head_dim)
                current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
                ip_attn_output += scale * current_ip_hidden_states

            return hidden_states, encoder_hidden_states, ip_attn_output
        else:
            return hidden_states


class FluxAttention(torch.nn.Module, AttentionModuleMixin):
    _default_processor_cls = FluxAttnProcessor
    _available_processors = [
        FluxAttnProcessor,
        FluxIPAdapterAttnProcessor,
    ]

    def __init__(
        self,
        query_dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        added_kv_proj_dim: Optional[int] = None,
        added_proj_bias: Optional[bool] = True,
        out_bias: bool = True,
        eps: float = 1e-5,
        out_dim: int = None,
        context_pre_only: Optional[bool] = None,
        pre_only: bool = False,
        elementwise_affine: bool = True,
        processor=None,
    ):
        super().__init__()

        self.head_dim = dim_head
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.query_dim = query_dim
        self.use_bias = bias
        self.dropout = dropout
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.context_pre_only = context_pre_only
        self.pre_only = pre_only
        self.heads = out_dim // dim_head if out_dim is not None else heads
        self.added_kv_proj_dim = added_kv_proj_dim
        self.added_proj_bias = added_proj_bias

        self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
        self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
        self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)

        if not self.pre_only:
            self.to_out = torch.nn.ModuleList([])
            self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
            self.to_out.append(torch.nn.Dropout(dropout))

        if added_kv_proj_dim is not None:
            self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
            self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
            self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)

        if processor is None:
            processor = self._default_processor_cls()
        self.set_processor(processor)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
        quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
        unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters]
        if len(unused_kwargs) > 0:
            logger.warning(
                f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
            )
        kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
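    # A minimal construction sketch for the module above (the sizes are illustrative assumptions, mirroring how
    # the transformer blocks below instantiate it):
    #
    #     attn = FluxAttention(query_dim=3072, heads=24, dim_head=128, out_dim=3072, bias=True, pre_only=True)
    #     hidden_states = torch.randn(1, 16, 3072)
    #     out = attn(hidden_states)  # dispatches to the default FluxAttnProcessor set in __init__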


@maybe_allow_in_graph
class FluxSingleTransformerBlock(nn.Module):
    def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
@@ -54,6 +355,8 @@ class FluxSingleTransformerBlock(nn.Module):
        self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)

        if is_torch_npu_available():
            from ..attention_processor import FluxAttnProcessor2_0_NPU

            deprecation_message = (
                "Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
                "should be set explicitly using the `set_attn_processor` method."
@@ -61,17 +364,15 @@ class FluxSingleTransformerBlock(nn.Module):
            deprecate("npu_processor", "0.34.0", deprecation_message)
            processor = FluxAttnProcessor2_0_NPU()
        else:
            processor = FluxAttnProcessor2_0()
            processor = FluxAttnProcessor()

        self.attn = Attention(
        self.attn = FluxAttention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            bias=True,
            processor=processor,
            qk_norm="rms_norm",
            eps=1e-6,
            pre_only=True,
        )
@@ -118,17 +419,15 @@ class FluxTransformerBlock(nn.Module):
        self.norm1 = AdaLayerNormZero(dim)
        self.norm1_context = AdaLayerNormZero(dim)

        self.attn = Attention(
        self.attn = FluxAttention(
            query_dim=dim,
            cross_attention_dim=None,
            added_kv_proj_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            context_pre_only=False,
            bias=True,
            processor=FluxAttnProcessor2_0(),
            qk_norm=qk_norm,
            processor=FluxAttnProcessor(),
            eps=eps,
        )

@@ -152,6 +451,7 @@ class FluxTransformerBlock(nn.Module):
            encoder_hidden_states, emb=temb
        )
        joint_attention_kwargs = joint_attention_kwargs or {}

        # Attention.
        attention_outputs = self.attn(
            hidden_states=norm_hidden_states,
@@ -180,7 +480,6 @@ class FluxTransformerBlock(nn.Module):
            hidden_states = hidden_states + ip_attn_output

        # Process attention outputs for the `encoder_hidden_states`.

        context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
        encoder_hidden_states = encoder_hidden_states + context_attn_output

@@ -195,8 +494,45 @@ class FluxTransformerBlock(nn.Module):
        return encoder_hidden_states, hidden_states


class FluxPosEmbed(nn.Module):
    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
    def __init__(self, theta: int, axes_dim: List[int]):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        n_axes = ids.shape[-1]
        cos_out = []
        sin_out = []
        pos = ids.float()
        is_mps = ids.device.type == "mps"
        is_npu = ids.device.type == "npu"
        freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
        for i in range(n_axes):
            cos, sin = get_1d_rotary_pos_embed(
                self.axes_dim[i],
                pos[:, i],
                theta=self.theta,
                repeat_interleave_real=True,
                use_real=True,
                freqs_dtype=freqs_dtype,
            )
            cos_out.append(cos)
            sin_out.append(sin)
        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
        return freqs_cos, freqs_sin


class FluxTransformer2DModel(
    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
    ModelMixin,
    ConfigMixin,
    PeftAdapterMixin,
    FromOriginalModelMixin,
    FluxTransformer2DLoadersMixin,
    CacheMixin,
    AttentionMixin,
):
    """
    The Transformer model introduced in Flux.
@@ -292,106 +628,6 @@ class FluxTransformer2DModel(
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
|
||||
def fuse_qkv_projections(self):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
||||
are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
self.original_attn_processors = None
|
||||
|
||||
for _, attn_processor in self.attn_processors.items():
|
||||
if "Added" in str(attn_processor.__class__.__name__):
|
||||
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
||||
|
||||
self.original_attn_processors = self.attn_processors
|
||||
|
||||
for module in self.modules():
|
||||
if isinstance(module, Attention):
|
||||
module.fuse_projections(fuse=True)
|
||||
|
||||
self.set_attn_processor(FusedFluxAttnProcessor2_0())
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
||||
def unfuse_qkv_projections(self):
|
||||
"""Disables the fused QKV projection if enabled.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
"""
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
|
||||
@@ -67,6 +67,9 @@ from .import_utils import (
    is_bitsandbytes_version,
    is_bs4_available,
    is_cosmos_guardrail_available,
    is_flash_attn_3_available,
    is_flash_attn_available,
    is_flash_attn_version,
    is_flax_available,
    is_ftfy_available,
    is_gguf_available,
@@ -90,6 +93,8 @@ from .import_utils import (
    is_peft_version,
    is_pytorch_retinaface_available,
    is_safetensors_available,
    is_sageattention_available,
    is_sageattention_version,
    is_scipy_available,
    is_sentencepiece_available,
    is_tensorboard_available,
@@ -108,6 +113,7 @@ from .import_utils import (
    is_unidecode_available,
    is_wandb_available,
    is_xformers_available,
    is_xformers_version,
    requires_backends,
)
from .loading_utils import get_module_from_name, get_submodule_by_name, load_image, load_video

@@ -41,6 +41,8 @@ DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules"))
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
DIFFUSERS_REQUEST_TIMEOUT = 60
DIFFUSERS_ATTN_BACKEND = os.getenv("DIFFUSERS_ATTN_BACKEND", "native")
DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES
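# A minimal sketch of driving these new constants from the environment (the backend name and script name are
# illustrative assumptions; the backend must match one of the `AttentionBackendName` values):
#
#     DIFFUSERS_ATTN_BACKEND=flash DIFFUSERS_ATTN_CHECKS=1 python my_flux_script.py
#
# `DIFFUSERS_ATTN_BACKEND` selects the default dispatcher backend ("native" when unset), and
# `DIFFUSERS_ATTN_CHECKS` appears to toggle extra input validation in the attention dispatcher.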

# Below should be `True` if the current version of `peft` and `transformers` are compatible with
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are

@@ -258,6 +258,21 @@ class AsymmetricAutoencoderKL(metaclass=DummyObject):
        requires_backends(cls, ["torch"])


class AttentionBackendName(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class AuraFlowTransformer2DModel(metaclass=DummyObject):
    _backends = ["torch"]

@@ -1368,6 +1383,10 @@ class WanVACETransformer3DModel(metaclass=DummyObject):
        requires_backends(cls, ["torch"])


def attention_backend(*args, **kwargs):
    requires_backends(attention_backend, ["torch"])


class ComponentsManager(metaclass=DummyObject):
    _backends = ["torch"]


@@ -220,6 +220,9 @@ _pytorch_retinaface_available, _pytorch_retinaface_version = _is_package_availab
_better_profanity_available, _better_profanity_version = _is_package_available("better_profanity")
_nltk_available, _nltk_version = _is_package_available("nltk")
_cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("cosmos_guardrail")
_sageattention_available, _sageattention_version = _is_package_available("sageattention")
_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
_flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3")


def is_torch_available():
@@ -378,6 +381,18 @@ def is_hpu_available():
    return all(importlib.util.find_spec(lib) for lib in ("habana_frameworks", "habana_frameworks.torch"))


def is_sageattention_available():
    return _sageattention_available


def is_flash_attn_available():
    return _flash_attn_available


def is_flash_attn_3_available():
    return _flash_attn_3_available


# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
@@ -804,6 +819,51 @@ def is_optimum_quanto_version(operation: str, version: str):
    return compare_versions(parse(_optimum_quanto_version), operation, version)


def is_xformers_version(operation: str, version: str):
    """
    Compares the current xformers version to a given reference with an operation.

    Args:
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A version string
    """
    if not _xformers_available:
        return False
    return compare_versions(parse(_xformers_version), operation, version)


def is_sageattention_version(operation: str, version: str):
    """
    Compares the current sageattention version to a given reference with an operation.

    Args:
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A version string
    """
    if not _sageattention_available:
        return False
    return compare_versions(parse(_sageattention_version), operation, version)


def is_flash_attn_version(operation: str, version: str):
    """
    Compares the current flash-attention version to a given reference with an operation.

    Args:
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A version string
    """
    if not _flash_attn_available:
        return False
    return compare_versions(parse(_flash_attn_version), operation, version)


def get_objects_from_module(module):
    """
    Returns a dict of object names and values in a module, while skipping private/internal objects

@@ -7,12 +7,7 @@ from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKL, ChromaPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler
from diffusers.utils.testing_utils import torch_device

from ..test_pipelines_common import (
    FluxIPAdapterTesterMixin,
    PipelineTesterMixin,
    check_qkv_fusion_matches_attn_procs_length,
    check_qkv_fusion_processors_exist,
)
from ..test_pipelines_common import FluxIPAdapterTesterMixin, PipelineTesterMixin, check_qkv_fused_layers_exist


class ChromaPipelineFastTests(
@@ -126,12 +121,10 @@ class ChromaPipelineFastTests(
        # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
        # to the pipeline level.
        pipe.transformer.fuse_qkv_projections()
        assert check_qkv_fusion_processors_exist(pipe.transformer), (
            "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
        self.assertTrue(
            check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
            ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
        )
        assert check_qkv_fusion_matches_attn_procs_length(
            pipe.transformer, pipe.transformer.original_attn_processors
        ), "Something wrong with the attention processors concerning the fused QKV projections."

        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images

@@ -8,12 +8,7 @@ from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKL, ChromaImg2ImgPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler
from diffusers.utils.testing_utils import floats_tensor, torch_device

from ..test_pipelines_common import (
    FluxIPAdapterTesterMixin,
    PipelineTesterMixin,
    check_qkv_fusion_matches_attn_procs_length,
    check_qkv_fusion_processors_exist,
)
from ..test_pipelines_common import FluxIPAdapterTesterMixin, PipelineTesterMixin, check_qkv_fused_layers_exist


class ChromaImg2ImgPipelineFastTests(
@@ -129,12 +124,10 @@ class ChromaImg2ImgPipelineFastTests(
        # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
        # to the pipeline level.
        pipe.transformer.fuse_qkv_projections()
        assert check_qkv_fusion_processors_exist(pipe.transformer), (
            "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
        self.assertTrue(
            check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
            ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
        )
        assert check_qkv_fusion_matches_attn_procs_length(
            pipe.transformer, pipe.transformer.original_attn_processors
        ), "Something wrong with the attention processors concerning the fused QKV projections."

        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images

@@ -16,11 +16,7 @@ from diffusers.utils.testing_utils import (
)
from diffusers.utils.torch_utils import randn_tensor

from ..test_pipelines_common import (
    PipelineTesterMixin,
    check_qkv_fusion_matches_attn_procs_length,
    check_qkv_fusion_processors_exist,
)
from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist


class FluxControlNetImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
@@ -170,12 +166,10 @@ class FluxControlNetImg2ImgPipelineFastTestsMi
        original_image_slice = image[0, -3:, -3:, -1]

        pipe.transformer.fuse_qkv_projections()
        assert check_qkv_fusion_processors_exist(pipe.transformer), (
            "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
        self.assertTrue(
            check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
            ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
        )
        assert check_qkv_fusion_matches_attn_procs_length(
            pipe.transformer, pipe.transformer.original_attn_processors
        ), "Something wrong with the attention processors concerning the fused QKV projections."

        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images

@@ -28,8 +28,7 @@ from ..test_pipelines_common import (
    FluxIPAdapterTesterMixin,
    PipelineTesterMixin,
    PyramidAttentionBroadcastTesterMixin,
    check_qkv_fusion_matches_attn_procs_length,
    check_qkv_fusion_processors_exist,
    check_qkv_fused_layers_exist,
)


@@ -171,12 +170,10 @@ class FluxPipelineFastTests(
        # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
        # to the pipeline level.
        pipe.transformer.fuse_qkv_projections()
        assert check_qkv_fusion_processors_exist(pipe.transformer), (
            "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
        self.assertTrue(
            check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
            ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
        )
        assert check_qkv_fusion_matches_attn_procs_length(
            pipe.transformer, pipe.transformer.original_attn_processors
        ), "Something wrong with the attention processors concerning the fused QKV projections."

        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images

@@ -8,11 +8,7 @@ from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPToken
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxTransformer2DModel
from diffusers.utils.testing_utils import torch_device

from ..test_pipelines_common import (
    PipelineTesterMixin,
    check_qkv_fusion_matches_attn_procs_length,
    check_qkv_fusion_processors_exist,
)
from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist


class FluxControlPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
@@ -140,12 +136,10 @@ class FluxControlPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
        # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
        # to the pipeline level.
        pipe.transformer.fuse_qkv_projections()
        assert check_qkv_fusion_processors_exist(pipe.transformer), (
            "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
        self.assertTrue(
            check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
            ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
        )
        assert check_qkv_fusion_matches_attn_procs_length(
            pipe.transformer, pipe.transformer.original_attn_processors
        ), "Something wrong with the attention processors concerning the fused QKV projections."

        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images

@@ -15,11 +15,7 @@ from diffusers.utils.testing_utils import (
    torch_device,
)

from ..test_pipelines_common import (
    PipelineTesterMixin,
    check_qkv_fusion_matches_attn_procs_length,
    check_qkv_fusion_processors_exist,
)
from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist


class FluxControlInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
@@ -134,12 +130,10 @@ class FluxControlInpaintPipelineFastTestsMixin
        # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
        # to the pipeline level.
        pipe.transformer.fuse_qkv_projections()
        assert check_qkv_fusion_processors_exist(pipe.transformer), (
            "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
        self.assertTrue(
            check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
            ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
        )
        assert check_qkv_fusion_matches_attn_procs_length(
            pipe.transformer, pipe.transformer.original_attn_processors
        ), "Something wrong with the attention processors concerning the fused QKV projections."

        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images

@@ -37,6 +37,7 @@ from diffusers.hooks.first_block_cache import FirstBlockCacheConfig
from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin
from diffusers.models.attention import AttentionModuleMixin
from diffusers.models.attention_processor import AttnProcessor
from diffusers.models.controlnets.controlnet_xs import UNetControlNetXSModel
from diffusers.models.unets.unet_3d_condition import UNet3DConditionModel
@@ -98,6 +99,20 @@ def check_qkv_fusion_processors_exist(model):
    return all(p.startswith("Fused") for p in proc_names)


def check_qkv_fused_layers_exist(model, layer_names):
    is_fused_submodules = []
    for submodule in model.modules():
        if not isinstance(submodule, AttentionModuleMixin):
            continue
        is_fused_attribute_set = submodule.fused_projections
        is_fused_layer = True
        for layer in layer_names:
            is_fused_layer = is_fused_layer and getattr(submodule, layer, None) is not None
        is_fused = is_fused_attribute_set and is_fused_layer
        is_fused_submodules.append(is_fused)
    return all(is_fused_submodules)
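# A minimal usage sketch for the helper above, mirroring how the updated pipeline tests call it (`pipe` is
# assumed to be a Flux/Chroma pipeline whose transformer attention modules use `AttentionModuleMixin`):
#
#     pipe.transformer.fuse_qkv_projections()
#     assert check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"])
#     pipe.transformer.unfuse_qkv_projections()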


class SDFunctionTesterMixin:
    """
    This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.