
use dispatch_attention_fn for multiple attention backend support

This commit is contained in:
DavidBert
2025-10-17 15:21:24 +00:00
parent 836cd12a18
commit 0ef0dc6837
7 changed files with 85 additions and 81 deletions

View File

@@ -13,9 +13,6 @@ from typing import Dict, Tuple
import torch
from safetensors.torch import save_file
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel
from diffusers.pipelines.photon import PhotonPipeline
@@ -74,14 +71,12 @@ def create_parameter_mapping(depth: int) -> dict:
mapping[f"blocks.{i}.txt_kv_proj.weight"] = f"blocks.{i}.attention.txt_kv_proj.weight"
# QK norm moved to attention module and renamed to match Attention's qk_norm structure
# Old: qk_norm.query_norm / qk_norm.key_norm -> New: norm_q / norm_k
mapping[f"blocks.{i}.qk_norm.query_norm.scale"] = f"blocks.{i}.attention.norm_q.weight"
mapping[f"blocks.{i}.qk_norm.key_norm.scale"] = f"blocks.{i}.attention.norm_k.weight"
mapping[f"blocks.{i}.qk_norm.query_norm.weight"] = f"blocks.{i}.attention.norm_q.weight"
mapping[f"blocks.{i}.qk_norm.key_norm.weight"] = f"blocks.{i}.attention.norm_k.weight"
# K norm for text tokens moved to attention module
# Old: k_norm -> New: norm_added_k
mapping[f"blocks.{i}.k_norm.scale"] = f"blocks.{i}.attention.norm_added_k.weight"
mapping[f"blocks.{i}.k_norm.weight"] = f"blocks.{i}.attention.norm_added_k.weight"
@@ -306,7 +301,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert Photon checkpoint to diffusers format")
parser.add_argument(
"--checkpoint_path", type=str, required=True, help="Path to the original Photon checkpoint (.pth file)"
"--checkpoint_path", type=str, required=True, help="Path to the original Photon checkpoint (.pth file )"
)
parser.add_argument(

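The mapping hunk above renames the original Photon RMSNorm parameters (`*.scale`) to the diffusers Attention naming (`norm_q.weight`, `norm_k.weight`, `norm_added_k.weight`). A minimal sketch of how such a key mapping is typically applied to a checkpoint state dict; the helper below is illustrative and is not the conversion script's actual code:

import torch

def remap_state_dict(state_dict: dict, mapping: dict) -> dict:
    # Rename keys according to `mapping`; keys without an entry keep their original name.
    return {mapping.get(old_key, old_key): tensor for old_key, tensor in state_dict.items()}

# Illustrative example for a single block, mirroring the renames in the hunk above.
mapping = {
    "blocks.0.qk_norm.query_norm.scale": "blocks.0.attention.norm_q.weight",
    "blocks.0.qk_norm.key_norm.scale": "blocks.0.attention.norm_k.weight",
    "blocks.0.k_norm.scale": "blocks.0.attention.norm_added_k.weight",
}
old_sd = {key: torch.randn(8) for key in mapping}
new_sd = remap_state_dict(old_sd, mapping)
print(sorted(new_sd))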
View File

@@ -516,6 +516,7 @@ else:
"MusicLDMPipeline",
"OmniGenPipeline",
"PaintByExamplePipeline",
"PhotonPipeline",
"PIAPipeline",
"PixArtAlphaPipeline",
"PixArtSigmaPAGPipeline",
@@ -1180,6 +1181,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
MusicLDMPipeline,
OmniGenPipeline,
PaintByExamplePipeline,
PhotonPipeline,
PIAPipeline,
PixArtAlphaPipeline,
PixArtSigmaPAGPipeline,

View File

@@ -191,6 +191,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LuminaNextDiT2DModel,
MochiTransformer3DModel,
OmniGenTransformer2DModel,
PhotonTransformer2DModel,
PixArtTransformer2DModel,
PriorTransformer,
QwenImageTransformer2DModel,

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from torch import Tensor, nn
@@ -21,7 +21,7 @@ from torch.nn.functional import fold, unfold
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ..attention import AttentionMixin, AttentionModuleMixin
from ..attention_processor import Attention
from ..attention_dispatch import dispatch_attention_fn
from ..embeddings import get_timestep_embedding
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
@@ -35,7 +35,7 @@ def get_image_ids(batch_size: int, height: int, width: int, patch_size: int, dev
r"""
Generates 2D patch coordinate indices for a batch of images.
Parameters:
Args:
batch_size (`int`):
Number of images in the batch.
height (`int`):
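get_image_ids produces one 2D patch coordinate per latent patch, repeated across the batch, which the positional embedder later turns into rotary embeddings. An illustrative sketch (meshgrid-based; the exact index layout used by the module may differ):

import torch

def get_image_ids_sketch(batch_size: int, height: int, width: int, patch_size: int, device="cpu") -> torch.Tensor:
    # One (row, col) index per non-overlapping patch, shared across the batch -> (B, L, 2).
    rows = torch.arange(height // patch_size, device=device)
    cols = torch.arange(width // patch_size, device=device)
    grid = torch.stack(torch.meshgrid(rows, cols, indexing="ij"), dim=-1)  # (H/p, W/p, 2)
    return grid.reshape(1, -1, 2).expand(batch_size, -1, -1)

print(get_image_ids_sketch(2, 8, 8, 2).shape)  # torch.Size([2, 16, 2])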
@@ -63,7 +63,7 @@ def apply_rope(xq: Tensor, freqs_cis: Tensor) -> Tensor:
r"""
Applies rotary positional embeddings (RoPE) to a query tensor.
Parameters:
Args:
xq (`torch.Tensor`):
Input tensor of shape `(..., dim)` representing the queries.
freqs_cis (`torch.Tensor`):
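apply_rope rotates query features by position-dependent angles. Below is a reference sketch of the standard pairwise RoPE rotation; the module itself precomputes the cos/sin terms as freqs_cis, so treat this as an equivalent formulation rather than the file's exact code:

import torch

def rope_rotate(x: torch.Tensor, positions: torch.Tensor, theta: float = 10000.0) -> torch.Tensor:
    # Rotate consecutive (even, odd) feature pairs of `x` by position-dependent angles.
    dim = x.shape[-1]
    freqs = 1.0 / theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)  # (dim/2,)
    angles = positions[..., None].float() * freqs                                # (..., dim/2)
    cos, sin = angles.cos(), angles.sin()
    x_pairs = x.float().reshape(*x.shape[:-1], -1, 2)                            # (..., dim/2, 2)
    x_even, x_odd = x_pairs[..., 0], x_pairs[..., 1]
    rotated = torch.stack((x_even * cos - x_odd * sin, x_even * sin + x_odd * cos), dim=-1)
    return rotated.reshape(x.shape).type_as(x)

q = torch.randn(2, 8, 64)              # (batch, seq, dim)
pos = torch.arange(8).expand(2, 8)     # one position index per token
print(rope_rotate(q, pos).shape)       # torch.Size([2, 8, 64])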
@@ -82,11 +82,12 @@ def apply_rope(xq: Tensor, freqs_cis: Tensor) -> Tensor:
class PhotonAttnProcessor2_0:
r"""
Processor for implementing Photon-style attention with multi-source tokens and RoPE. Properly integrates with
diffusers Attention module while handling Photon-specific logic.
Processor for implementing Photon-style attention with multi-source tokens and RoPE. Supports multiple attention
backends (Flash Attention, Sage Attention, etc.) via dispatch_attention_fn.
"""
_attention_backend = None
_parallel_config = None
def __init__(self):
if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
@@ -104,7 +105,7 @@ class PhotonAttnProcessor2_0:
"""
Apply Photon attention using PhotonAttention module.
Parameters:
Args:
attn: PhotonAttention module containing projection layers
hidden_states: Image tokens [B, L_img, D]
encoder_hidden_states: Text tokens [B, L_txt, D]
@@ -113,9 +114,7 @@ class PhotonAttnProcessor2_0:
"""
if encoder_hidden_states is None:
raise ValueError(
"PhotonAttnProcessor2_0 requires 'encoder_hidden_states' containing text tokens."
)
raise ValueError("PhotonAttnProcessor2_0 requires 'encoder_hidden_states' containing text tokens.")
# Project image tokens to Q, K, V
img_qkv = attn.img_qkv_proj(hidden_states)
@@ -164,14 +163,24 @@ class PhotonAttnProcessor2_0:
joint_mask = torch.cat([attention_mask, ones_img], dim=-1)
attn_mask_tensor = joint_mask[:, None, None, :].expand(-1, attn.heads, l_img, -1)
# Apply scaled dot-product attention
attn_output = torch.nn.functional.scaled_dot_product_attention(
img_q.contiguous(), k.contiguous(), v.contiguous(), attn_mask=attn_mask_tensor
# Apply attention using dispatch_attention_fn for backend support
# Reshape to match dispatch_attention_fn expectations: [B, L, H, D]
query = img_q.transpose(1, 2) # [B, L_img, H, D]
key = k.transpose(1, 2) # [B, L_txt + L_img, H, D]
value = v.transpose(1, 2) # [B, L_txt + L_img, H, D]
attn_output = dispatch_attention_fn(
query,
key,
value,
attn_mask=attn_mask_tensor,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
# Reshape from [B, H, L_img, D] to [B, L_img, H*D]
batch_size, num_heads, seq_len, head_dim = attn_output.shape
attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, num_heads * head_dim)
# Reshape from [B, L_img, H, D] to [B, L_img, H*D]
batch_size, seq_len, num_heads, head_dim = attn_output.shape
attn_output = attn_output.reshape(batch_size, seq_len, num_heads * head_dim)
# Apply output projection
attn_output = attn.to_out[0](attn_output)
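The hunk above swaps the hard-coded torch SDPA call for dispatch_attention_fn, which (per the new comments) takes and returns tensors in [B, L, H, D] layout instead of SDPA's [B, H, L, D], so the backend (Flash Attention, Sage Attention, native SDPA, ...) can be selected at runtime. A minimal pure-PyTorch sketch of that layout round trip, with a stand-in function playing the role of the dispatcher; the real dispatch_attention_fn lives in diffusers.models.attention_dispatch and routes to the configured backend:

import torch
import torch.nn.functional as F

def dispatch_attention_sketch(query, key, value, attn_mask=None):
    # Stand-in for dispatch_attention_fn: accepts and returns [B, L, H, D], runs SDPA internally.
    out = F.scaled_dot_product_attention(
        query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), attn_mask=attn_mask
    )
    return out.transpose(1, 2)

batch, heads, l_img, l_txt, head_dim = 2, 4, 16, 8, 32
img_q = torch.randn(batch, heads, l_img, head_dim)           # projections come out as [B, H, L, D]
k = torch.randn(batch, heads, l_txt + l_img, head_dim)
v = torch.randn(batch, heads, l_txt + l_img, head_dim)
joint_mask = torch.ones(batch, heads, l_img, l_txt + l_img, dtype=torch.bool)

# Transpose to the [B, L, H, D] layout the dispatcher expects, then merge heads afterwards.
out = dispatch_attention_sketch(img_q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), joint_mask)
out = out.reshape(batch, l_img, heads * head_dim)            # [B, L_img, H * D], as in the new code
print(out.shape)                                             # torch.Size([2, 16, 128])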
@@ -183,8 +192,8 @@ class PhotonAttnProcessor2_0:
class PhotonAttention(nn.Module, AttentionModuleMixin):
r"""
Photon-style attention module that handles multi-source tokens and RoPE.
Similar to FluxAttention but adapted for Photon's architecture.
Photon-style attention module that handles multi-source tokens and RoPE. Similar to FluxAttention but adapted for
Photon's architecture.
"""
_default_processor_cls = PhotonAttnProcessor2_0
@@ -242,14 +251,14 @@ class PhotonAttention(nn.Module, AttentionModuleMixin):
# inspired from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
class PhotoEmbedND(nn.Module):
class PhotonEmbedND(nn.Module):
r"""
N-dimensional rotary positional embedding.
This module creates rotary embeddings (RoPE) across multiple axes, where each axis can have its own embedding
dimension. The embeddings are combined and returned as a single tensor
Parameters:
Args:
dim (int):
Base embedding dimension (must be even).
theta (int):
@@ -258,7 +267,7 @@ class PhotoEmbedND(nn.Module):
List of embedding dimensions for each axis (each must be even).
"""
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
def __init__(self, dim: int, theta: int, axes_dim: List[int]):
super().__init__()
self.dim = dim
self.theta = theta
@@ -288,7 +297,7 @@ class MLPEmbedder(nn.Module):
r"""
A simple 2-layer MLP used for embedding inputs.
Parameters:
Args:
in_dim (`int`):
Dimensionality of the input features.
hidden_dim (`int`):
@@ -316,7 +325,7 @@ class Modulation(nn.Module):
Given an input vector, the module projects it through a linear layer to produce six chunks, which are grouped into
two tuples `(shift, scale, gate)`.
Parameters:
Args:
dim (`int`):
Dimensionality of the input vector. The output will have `6 * dim` features internally.
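The Modulation docstring above describes a single linear projection that yields six chunks, grouped into two (shift, scale, gate) tuples. An illustrative stand-in, assuming plain chunking of the projection output:

import torch
from torch import nn

class ModulationSketch(nn.Module):
    # Illustrative: project a vector to 6 * dim features and split into two (shift, scale, gate) triples.
    def __init__(self, dim: int):
        super().__init__()
        self.lin = nn.Linear(dim, 6 * dim)

    def forward(self, vec: torch.Tensor):
        shift1, scale1, gate1, shift2, scale2, gate2 = self.lin(vec).chunk(6, dim=-1)
        return (shift1, scale1, gate1), (shift2, scale2, gate2)

mod = ModulationSketch(dim=16)
attn_mod, mlp_mod = mod(torch.randn(2, 16))
print([t.shape for t in attn_mod])  # three tensors of shape (2, 16)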
@@ -340,7 +349,7 @@ class PhotonBlock(nn.Module):
r"""
Multimodal transformer block with text-image cross-attention, modulation, and MLP.
Parameters:
Args:
hidden_size (`int`):
Dimension of the hidden representations.
num_heads (`int`):
@@ -421,7 +430,7 @@ class PhotonBlock(nn.Module):
r"""
Runs modulation-gated cross-attention and MLP, with residual connections.
Parameters:
Args:
hidden_states (`torch.Tensor`):
Image tokens of shape `(B, L_img, hidden_size)`.
encoder_hidden_states (`torch.Tensor`):
@@ -468,7 +477,7 @@ class FinalLayer(nn.Module):
This layer applies a normalized and modulated transformation to input tokens and projects them into patch-level
outputs.
Parameters:
Args:
hidden_size (`int`):
Dimensionality of the input tokens.
patch_size (`int`):
@@ -505,7 +514,7 @@ def img2seq(img: Tensor, patch_size: int) -> Tensor:
r"""
Flattens an image tensor into a sequence of non-overlapping patches.
Parameters:
Args:
img (`torch.Tensor`):
Input image tensor of shape `(B, C, H, W)`.
patch_size (`int`):
@@ -523,7 +532,7 @@ def seq2img(seq: Tensor, patch_size: int, shape: Tensor) -> Tensor:
r"""
Reconstructs an image tensor from a sequence of patches (inverse of `img2seq`).
Parameters:
Args:
seq (`torch.Tensor`):
Patch sequence of shape `(B, L, C * patch_size * patch_size)`, where `L = (H // patch_size) * (W //
patch_size)`.
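img2seq and seq2img, whose docstrings are touched above, convert a (B, C, H, W) latent into (B, L, C * patch_size**2) patch tokens and back. A minimal round-trip sketch with torch's unfold/fold (which the file already imports); the bodies below are illustrative, not copied from the module:

import torch
from torch.nn.functional import fold, unfold

def img2seq_sketch(img: torch.Tensor, patch_size: int) -> torch.Tensor:
    # (B, C, H, W) -> (B, L, C * p * p) with L = (H // p) * (W // p)
    return unfold(img, kernel_size=patch_size, stride=patch_size).transpose(1, 2)

def seq2img_sketch(seq: torch.Tensor, patch_size: int, shape: torch.Size) -> torch.Tensor:
    # (B, L, C * p * p) -> (B, C, H, W); `shape` carries the original spatial size
    return fold(seq.transpose(1, 2), output_size=shape[-2:], kernel_size=patch_size, stride=patch_size)

x = torch.randn(1, 16, 8, 8)
seq = img2seq_sketch(x, patch_size=2)                 # (1, 16, 64)
assert torch.equal(seq2img_sketch(seq, 2, x.shape), x)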
@@ -550,7 +559,7 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
r"""
Transformer-based 2D model for text to image generation.
Parameters:
Args:
in_channels (`int`, *optional*, defaults to 16):
Number of input channels in the latent image.
patch_size (`int`, *optional*, defaults to 2):
@@ -650,7 +659,7 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
self.hidden_size = hidden_size
self.num_heads = num_heads
self.pe_embedder = PhotoEmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
self.pe_embedder = PhotonEmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
self.img_in = nn.Linear(self.in_channels * self.patch_size**2, self.hidden_size, bias=True)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
self.txt_in = nn.Linear(context_in_dim, self.hidden_size)
@@ -683,11 +692,10 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
def forward(
self,
image_latent: Tensor,
hidden_states: Tensor,
timestep: Tensor,
cross_attn_conditioning: Tensor,
micro_conditioning: Tensor,
cross_attn_mask: None | Tensor = None,
encoder_hidden_states: Tensor,
attention_mask: Optional[Tensor] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
@@ -697,16 +705,14 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
The latent image is split into patch tokens, combined with text conditioning, and processed through a stack of
transformer blocks modulated by the timestep. The output is reconstructed into the latent image space.
Parameters:
image_latent (`torch.Tensor`):
Args:
hidden_states (`torch.Tensor`):
Input latent image tensor of shape `(B, C, H, W)`.
timestep (`torch.Tensor`):
Timestep tensor of shape `(B,)` or `(1,)`, used for temporal conditioning.
cross_attn_conditioning (`torch.Tensor`):
encoder_hidden_states (`torch.Tensor`):
Text conditioning tensor of shape `(B, L_txt, context_in_dim)`.
micro_conditioning (`torch.Tensor`):
Extra conditioning vector (currently unused, reserved for future use).
cross_attn_mask (`torch.Tensor`, *optional*):
attention_mask (`torch.Tensor`, *optional*):
Boolean mask of shape `(B, L_txt)`, where `0` marks padding in the text sequence.
attention_kwargs (`dict`, *optional*):
Additional arguments passed to attention layers.
@@ -719,15 +725,15 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
- `sample` (`torch.Tensor`): Output latent image of shape `(B, C, H, W)`.
"""
# Process text conditioning
txt = self.txt_in(cross_attn_conditioning)
txt = self.txt_in(encoder_hidden_states)
# Convert image to sequence and embed
img = img2seq(image_latent, self.patch_size)
img = img2seq(hidden_states, self.patch_size)
img = self.img_in(img)
# Generate positional embeddings
bs, _, h, w = image_latent.shape
img_ids = get_image_ids(bs, h, w, patch_size=self.patch_size, device=image_latent.device)
bs, _, h, w = hidden_states.shape
img_ids = get_image_ids(bs, h, w, patch_size=self.patch_size, device=hidden_states.device)
pe = self.pe_embedder(img_ids)
# Compute time embedding
@@ -742,7 +748,7 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
txt,
vec,
pe,
cross_attn_mask,
attention_mask,
)
else:
img = block(
@@ -750,12 +756,12 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
encoder_hidden_states=txt,
temb=vec,
image_rotary_emb=pe,
attention_mask=cross_attn_mask,
attention_mask=attention_mask,
)
# Final layer and convert back to image
img = self.final_layer(img, vec)
output = seq2img(img, self.patch_size, image_latent.shape)
output = seq2img(img, self.patch_size, hidden_states.shape)
if not return_dict:
return (output,)
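With the signature change above, callers pass hidden_states / encoder_hidden_states / attention_mask instead of image_latent / cross_attn_conditioning / cross_attn_mask, and micro_conditioning is gone. A hedged usage sketch under the new names, assuming the checkpoint id from the pipeline docstring further below and that in_channels and context_in_dim are exposed on the model config:

import torch
from diffusers import PhotonPipeline

# Checkpoint id taken from the pipeline docstring; the config field names here are assumptions.
pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft", torch_dtype=torch.float32)
transformer = pipe.transformer
cfg = transformer.config

latents = torch.randn(1, cfg.in_channels, 64, 64)    # (B, C, H, W) latent image
text_emb = torch.randn(1, 32, cfg.context_in_dim)    # (B, L_txt, context_in_dim)
text_mask = torch.ones(1, 32, dtype=torch.bool)      # 0 would mark padding tokens
timestep = torch.tensor([0.5])                       # continuous timestep, assumed in [0, 1]

with torch.no_grad():
    sample = transformer(
        hidden_states=latents,             # was: image_latent
        timestep=timestep,
        encoder_hidden_states=text_emb,    # was: cross_attn_conditioning
        attention_mask=text_mask,          # was: cross_attn_mask
        return_dict=False,
    )[0]
print(sample.shape)                        # (B, C, H, W), same spatial size as the input latents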

View File

@@ -718,6 +718,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionXLPAGPipeline,
)
from .paint_by_example import PaintByExamplePipeline
from .photon import PhotonPipeline
from .pia import PIAPipeline
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
from .qwenimage import (

View File

@@ -206,11 +206,11 @@ EXAMPLE_DOC_STRING = """
>>> from diffusers import PhotonPipeline
>>> # Load pipeline with from_pretrained
>>> pipe = PhotonPipeline.from_pretrained("path/to/photon_checkpoint")
>>> pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft")
>>> pipe.to("cuda")
>>> prompt = "A digital painting of a rusty, vintage tram on a sandy beach"
>>> image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0]
>>> image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]
>>> image.save("photon_output.png")
```
"""
@@ -717,11 +717,10 @@ class PhotonPipeline(
# Forward through transformer
noise_pred = self.transformer(
image_latent=latents_in,
hidden_states=latents_in,
timestep=t_cont,
cross_attn_conditioning=ca_embed,
micro_conditioning=None,
cross_attn_mask=ca_mask,
encoder_hidden_states=ca_embed,
attention_mask=ca_mask,
return_dict=False,
)[0]

View File

@@ -68,28 +68,28 @@ class PhotonPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
tokenizer.model_max_length = 64
torch.manual_seed(0)
encoder_params = dict(
vocab_size=tokenizer.vocab_size,
hidden_size=8,
intermediate_size=16,
num_hidden_layers=1,
num_attention_heads=2,
num_key_value_heads=1,
head_dim=4,
max_position_embeddings=64,
layer_types=["full_attention"],
attention_bias=False,
attention_dropout=0.0,
dropout_rate=0.0,
hidden_activation="gelu_pytorch_tanh",
rms_norm_eps=1e-06,
attn_logit_softcapping=50.0,
final_logit_softcapping=30.0,
query_pre_attn_scalar=4,
rope_theta=10000.0,
sliding_window=4096,
)
encoder_params = {
"vocab_size": tokenizer.vocab_size,
"hidden_size": 8,
"intermediate_size": 16,
"num_hidden_layers": 1,
"num_attention_heads": 2,
"num_key_value_heads": 1,
"head_dim": 4,
"max_position_embeddings": 64,
"layer_types": ["full_attention"],
"attention_bias": False,
"attention_dropout": 0.0,
"dropout_rate": 0.0,
"hidden_activation": "gelu_pytorch_tanh",
"rms_norm_eps": 1e-06,
"attn_logit_softcapping": 50.0,
"final_logit_softcapping": 30.0,
"query_pre_attn_scalar": 4,
"rope_theta": 10000.0,
"sliding_window": 4096,
}
encoder_config = T5GemmaModuleConfig(**encoder_params)
text_encoder_config = T5GemmaConfig(encoder=encoder_config, is_encoder_decoder=False, **encoder_params)
text_encoder = T5GemmaEncoder(text_encoder_config)