1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
Aryan
2025-07-08 04:06:22 +02:00
parent c87575dde6
commit 941911c538
7 changed files with 244 additions and 62 deletions

View File

@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable, Dict, Optional, Tuple, Union
import torch

View File

@@ -1244,31 +1244,21 @@ class FluxPosEmbed(nn.Module):
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
def __init__(self, theta: int, axes_dim: List[int]):
super().__init__()
from .transformers.transformer_flux import FluxPosEmbed as FluxPosEmbed_
deprecate(
"FluxPosEmbed",
"1.0.0",
"Importing and using `FluxPosEmbed` from `diffusers.models.embeddings` is deprecated. Please use `FluxPosEmbed` from `diffusers.models.transformers.transformer_flux` instead.",
)
self.theta = theta
self.axes_dim = axes_dim
self._rope = FluxPosEmbed_(theta, axes_dim)
def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
cos_out = []
sin_out = []
pos = ids.float()
is_mps = ids.device.type == "mps"
is_npu = ids.device.type == "npu"
freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
for i in range(n_axes):
cos, sin = get_1d_rotary_pos_embed(
self.axes_dim[i],
pos[:, i],
theta=self.theta,
repeat_interleave_real=True,
use_real=True,
freqs_dtype=freqs_dtype,
)
cos_out.append(cos)
sin_out.append(sin)
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
return freqs_cos, freqs_sin
return self._rope(ids)
class TimestepEmbedding(nn.Module):

View File

@@ -599,6 +599,50 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
low_cpu_mem_usage=low_cpu_mem_usage,
)
def set_attention_backend(self, backend: str) -> None:
"""
Set the attention backend for the model.
Args:
backend (`str`):
The name of the backend to set. Must be one of the available backends defined in
`AttentionBackendName`. Available backends can be found in
`diffusers.attention_dispatch.AttentionBackendName`. Defaults to torch native scaled dot product
attention as backend.
"""
from .attention import AttentionModuleMixin
from .attention_dispatch import AttentionBackendName
backend = backend.lower()
available_backends = {x.value for x in AttentionBackendName.__members__.values()}
if backend not in available_backends:
raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
backend = AttentionBackendName(backend)
for module in self.modules():
if not isinstance(module, AttentionModuleMixin):
continue
processor = module.processor
if processor is None or not hasattr(processor, "_attention_backend"):
continue
processor._attention_backend = backend
def reset_attention_backend(self) -> None:
"""
Resets the attention backend for the model. Following calls to `forward` will use the environment default or
the torch native scaled dot product attention.
"""
from .attention_processor import Attention, MochiAttention
attention_classes = (Attention, MochiAttention)
for module in self.modules():
if not isinstance(module, attention_classes):
continue
processor = module.processor
if processor is None or not hasattr(processor, "_attention_backend"):
continue
processor._attention_backend = None
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],

View File

@@ -24,10 +24,15 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention, AttentionMixin, FeedForward
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
from ..embeddings import (
CombinedTimestepGuidanceTextProjEmbeddings,
CombinedTimestepTextProjEmbeddings,
apply_rotary_emb,
get_1d_rotary_pos_embed,
)
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
@@ -73,7 +78,6 @@ class FluxAttnProcessor:
"""Public method to get projections based on whether we're using fused mode or not."""
if attn.is_fused and hasattr(attn, "to_qkv"):
return self._get_fused_projections(attn, hidden_states, encoder_hidden_states)
return self._get_projections(attn, hidden_states, encoder_hidden_states)
def __call__(
@@ -117,17 +121,10 @@ class FluxAttnProcessor:
value = torch.cat([encoder_value, value], dim=2)
if image_rotary_emb is not None:
from ..embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
hidden_states = dispatch_attention_fn(
query,
key,
value,
attn_mask=attention_mask,
)
hidden_states = dispatch_attention_fn(query, key, value, attn_mask=attention_mask)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
@@ -242,12 +239,10 @@ class FluxIPAdapterAttnProcessor(torch.nn.Module):
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
from .embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = dispatch_attention_fn(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
@@ -292,13 +287,76 @@ class FluxIPAdapterAttnProcessor(torch.nn.Module):
@maybe_allow_in_graph
class FluxAttention(Attention):
class FluxAttention(torch.nn.Module, AttentionModuleMixin):
_default_processor_cls = FluxAttnProcessor
_available_processors = [
FluxAttnProcessor,
FluxIPAdapterAttnProcessor,
]
def __init__(
self,
query_dim: int,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
qk_norm: Optional[str] = None,
added_kv_proj_dim: Optional[int] = None,
added_proj_bias: Optional[bool] = True,
out_bias: bool = True,
eps: float = 1e-5,
out_dim: int = None,
context_pre_only: Optional[bool] = None,
pre_only: bool = False,
elementwise_affine: bool = True,
processor=None,
):
super().__init__()
assert qk_norm == "rms_norm", "Flux uses RMSNorm"
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.query_dim = query_dim
self.use_bias = bias
self.dropout = dropout
self.out_dim = out_dim if out_dim is not None else query_dim
self.context_pre_only = context_pre_only
self.pre_only = pre_only
self.heads = out_dim // dim_head if out_dim is not None else heads
self.added_proj_bias = added_proj_bias
self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
if not self.pre_only:
self.to_out = torch.nn.ModuleList([])
self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
if added_kv_proj_dim is not None:
self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
if processor is None:
processor = self._default_processor_cls()
self.set_processor(processor)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
@maybe_allow_in_graph
class FluxSingleTransformerBlock(nn.Module):
@@ -330,20 +388,19 @@ class FluxSingleTransformerBlock(nn.Module):
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
residual = hidden_states
joint_attention_kwargs = joint_attention_kwargs or {}
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
joint_attention_kwargs = joint_attention_kwargs or {}
attn_output = self.attn(
hidden_states=norm_hidden_states,
image_rotary_emb=image_rotary_emb,
**joint_attention_kwargs,
)
attn_mlp_hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
proj_out = self.proj_out(attn_mlp_hidden_states)
hidden_states = hidden_states + gate.unsqueeze(1) * proj_out
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
gate = gate.unsqueeze(1)
hidden_states = gate * self.proj_out(hidden_states)
hidden_states = residual + hidden_states
if hidden_states.dtype == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
@@ -386,12 +443,13 @@ class FluxTransformerBlock(nn.Module):
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
joint_attention_kwargs = joint_attention_kwargs or {}
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb
)
joint_attention_kwargs = joint_attention_kwargs or {}
# Attention.
attention_outputs = self.attn(
hidden_states=norm_hidden_states,
@@ -410,7 +468,7 @@ class FluxTransformerBlock(nn.Module):
hidden_states = hidden_states + attn_output
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
norm_hidden_states = norm_hidden_states * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
ff_output = self.ff(norm_hidden_states)
ff_output = gate_mlp.unsqueeze(1) * ff_output
@@ -420,21 +478,54 @@ class FluxTransformerBlock(nn.Module):
hidden_states = hidden_states + ip_attn_output
# Process attention outputs for the `encoder_hidden_states`.
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
encoder_hidden_states = encoder_hidden_states + context_attn_output
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
norm_encoder_hidden_states = norm_encoder_hidden_states * (
1 + c_scale_mlp.unsqueeze(1)
) + c_shift_mlp.unsqueeze(1)
context_ff_output = self.ff_context(norm_encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
if encoder_hidden_states.dtype == torch.float16:
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
return encoder_hidden_states, hidden_states
class FluxPosEmbed(nn.Module):
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
def __init__(self, theta: int, axes_dim: List[int]):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
cos_out = []
sin_out = []
pos = ids.float()
is_mps = ids.device.type == "mps"
is_npu = ids.device.type == "npu"
freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
for i in range(n_axes):
cos, sin = get_1d_rotary_pos_embed(
self.axes_dim[i],
pos[:, i],
theta=self.theta,
repeat_interleave_real=True,
use_real=True,
freqs_dtype=freqs_dtype,
)
cos_out.append(cos)
sin_out.append(sin)
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
return freqs_cos, freqs_sin
class FluxTransformer2DModel(
ModelMixin,
ConfigMixin,
@@ -537,10 +628,6 @@ class FluxTransformer2DModel(
self.gradient_checkpointing = False
# Using inherited methods from AttentionMixin
# Using inherited methods from AttentionMixin
def forward(
self,
hidden_states: torch.Tensor,
@@ -634,11 +721,7 @@ class FluxTransformer2DModel(
for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
block, hidden_states, encoder_hidden_states, temb, image_rotary_emb
)
else:
@@ -665,12 +748,7 @@ class FluxTransformer2DModel(
for index_block, block in enumerate(self.single_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
temb,
image_rotary_emb,
)
hidden_states = self._gradient_checkpointing_func(block, hidden_states, temb, image_rotary_emb)
else:
hidden_states = block(

View File

@@ -67,6 +67,9 @@ from .import_utils import (
is_bitsandbytes_version,
is_bs4_available,
is_cosmos_guardrail_available,
is_flash_attn_3_available,
is_flash_attn_available,
is_flash_attn_version,
is_flax_available,
is_ftfy_available,
is_gguf_available,
@@ -90,6 +93,8 @@ from .import_utils import (
is_peft_version,
is_pytorch_retinaface_available,
is_safetensors_available,
is_sageattention_available,
is_sageattention_version,
is_scipy_available,
is_sentencepiece_available,
is_tensorboard_available,
@@ -108,6 +113,7 @@ from .import_utils import (
is_unidecode_available,
is_wandb_available,
is_xformers_available,
is_xformers_version,
requires_backends,
)
from .loading_utils import get_module_from_name, get_submodule_by_name, load_image, load_video

View File

@@ -41,6 +41,8 @@ DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules"))
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
DIFFUSERS_REQUEST_TIMEOUT = 60
DIFFUSERS_ATTN_BACKEND = os.getenv("DIFFUSERS_ATTN_BACKEND", "native")
DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES
# Below should be `True` if the current version of `peft` and `transformers` are compatible with
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are

View File

@@ -219,6 +219,9 @@ _pytorch_retinaface_available, _pytorch_retinaface_version = _is_package_availab
_better_profanity_available, _better_profanity_version = _is_package_available("better_profanity")
_nltk_available, _nltk_version = _is_package_available("nltk")
_cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("cosmos_guardrail")
_sageattention_available, _sageattention_version = _is_package_available("sageattention")
_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
_flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3")
def is_torch_available():
@@ -377,6 +380,18 @@ def is_hpu_available():
return all(importlib.util.find_spec(lib) for lib in ("habana_frameworks", "habana_frameworks.torch"))
def is_sageattention_available():
return _sageattention_available
def is_flash_attn_available():
return _flash_attn_available
def is_flash_attn_3_available():
return _flash_attn_3_available
# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
@@ -803,6 +818,51 @@ def is_optimum_quanto_version(operation: str, version: str):
return compare_versions(parse(_optimum_quanto_version), operation, version)
def is_xformers_version(operation: str, version: str):
"""
Compares the current xformers version to a given reference with an operation.
Args:
operation (`str`):
A string representation of an operator, such as `">"` or `"<="`
version (`str`):
A version string
"""
if not _xformers_available:
return False
return compare_versions(parse(_xformers_version), operation, version)
def is_sageattention_version(operation: str, version: str):
"""
Compares the current sageattention version to a given reference with an operation.
Args:
operation (`str`):
A string representation of an operator, such as `">"` or `"<="`
version (`str`):
A version string
"""
if not _sageattention_available:
return False
return compare_versions(parse(_sageattention_version), operation, version)
def is_flash_attn_version(operation: str, version: str):
"""
Compares the current flash-attention version to a given reference with an operation.
Args:
operation (`str`):
A string representation of an operator, such as `">"` or `"<="`
version (`str`):
A version string
"""
if not _flash_attn_available:
return False
return compare_versions(parse(_flash_attn_version), operation, version)
def get_objects_from_module(module):
"""
Returns a dict of object names and values in a module, while skipping private/internal objects