mirror of https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00

use dispatch_attention_fn for multiple attention backend support
@@ -13,9 +13,6 @@ from typing import Dict, Tuple

 import torch
 from safetensors.torch import save_file

 sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))

 from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel
 from diffusers.pipelines.photon import PhotonPipeline
@@ -74,14 +71,12 @@ def create_parameter_mapping(depth: int) -> dict:
         mapping[f"blocks.{i}.txt_kv_proj.weight"] = f"blocks.{i}.attention.txt_kv_proj.weight"

         # QK norm moved to attention module and renamed to match Attention's qk_norm structure
         # Old: qk_norm.query_norm / qk_norm.key_norm -> New: norm_q / norm_k
-        mapping[f"blocks.{i}.qk_norm.query_norm.scale"] = f"blocks.{i}.attention.norm_q.weight"
-        mapping[f"blocks.{i}.qk_norm.key_norm.scale"] = f"blocks.{i}.attention.norm_k.weight"
+        mapping[f"blocks.{i}.qk_norm.query_norm.weight"] = f"blocks.{i}.attention.norm_q.weight"
+        mapping[f"blocks.{i}.qk_norm.key_norm.weight"] = f"blocks.{i}.attention.norm_k.weight"

         # K norm for text tokens moved to attention module
         # Old: k_norm -> New: norm_added_k
-        mapping[f"blocks.{i}.k_norm.scale"] = f"blocks.{i}.attention.norm_added_k.weight"
+        mapping[f"blocks.{i}.k_norm.weight"] = f"blocks.{i}.attention.norm_added_k.weight"
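For readers tracing the conversion: a key mapping like the one built above is normally applied by walking the original state dict and renaming matching keys. A minimal sketch (remap_state_dict is a hypothetical helper, not part of this script; only the key names mirror the mapping):

import torch

def remap_state_dict(old_sd: dict, mapping: dict) -> dict:
    # Keys found in the mapping are renamed; all others pass through unchanged.
    return {mapping.get(key, key): tensor for key, tensor in old_sd.items()}

mapping = {"blocks.0.k_norm.weight": "blocks.0.attention.norm_added_k.weight"}
old_sd = {"blocks.0.k_norm.weight": torch.zeros(8)}
print(list(remap_state_dict(old_sd, mapping)))
# ['blocks.0.attention.norm_added_k.weight']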
@@ -306,7 +301,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Convert Photon checkpoint to diffusers format")

     parser.add_argument(
         "--checkpoint_path", type=str, required=True, help="Path to the original Photon checkpoint (.pth file)"
     )
     parser.add_argument(
@@ -516,6 +516,7 @@ else:
         "MusicLDMPipeline",
         "OmniGenPipeline",
         "PaintByExamplePipeline",
+        "PhotonPipeline",
         "PIAPipeline",
         "PixArtAlphaPipeline",
         "PixArtSigmaPAGPipeline",
@@ -1180,6 +1181,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         MusicLDMPipeline,
         OmniGenPipeline,
         PaintByExamplePipeline,
+        PhotonPipeline,
         PIAPipeline,
         PixArtAlphaPipeline,
         PixArtSigmaPAGPipeline,
@@ -191,6 +191,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         LuminaNextDiT2DModel,
         MochiTransformer3DModel,
         OmniGenTransformer2DModel,
+        PhotonTransformer2DModel,
         PixArtTransformer2DModel,
         PriorTransformer,
         QwenImageTransformer2DModel,
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
 from torch import Tensor, nn
@@ -21,7 +21,7 @@ from torch.nn.functional import fold, unfold
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import logging
 from ..attention import AttentionMixin, AttentionModuleMixin
-from ..attention_processor import Attention
+from ..attention_dispatch import dispatch_attention_fn
 from ..embeddings import get_timestep_embedding
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
@@ -35,7 +35,7 @@ def get_image_ids(batch_size: int, height: int, width: int, patch_size: int, dev
     r"""
     Generates 2D patch coordinate indices for a batch of images.

-    Parameters:
+    Args:
         batch_size (`int`):
             Number of images in the batch.
         height (`int`):
@@ -63,7 +63,7 @@ def apply_rope(xq: Tensor, freqs_cis: Tensor) -> Tensor:
     r"""
     Applies rotary positional embeddings (RoPE) to a query tensor.

-    Parameters:
+    Args:
         xq (`torch.Tensor`):
             Input tensor of shape `(..., dim)` representing the queries.
         freqs_cis (`torch.Tensor`):
@@ -82,11 +82,12 @@ def apply_rope(xq: Tensor, freqs_cis: Tensor) -> Tensor:

 class PhotonAttnProcessor2_0:
     r"""
-    Processor for implementing Photon-style attention with multi-source tokens and RoPE. Properly integrates with
-    diffusers Attention module while handling Photon-specific logic.
+    Processor for implementing Photon-style attention with multi-source tokens and RoPE. Supports multiple attention
+    backends (Flash Attention, Sage Attention, etc.) via dispatch_attention_fn.
     """

+    _attention_backend = None
+    _parallel_config = None
+
     def __init__(self):
         if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
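The two class attributes added here are what the processor later hands to dispatch_attention_fn. As a rough mental model of the dispatch pattern (a sketch of the general idea, not diffusers' actual implementation), think of a registry keyed by backend name, with PyTorch SDPA as the default when no backend is selected:

import torch.nn.functional as F

_BACKENDS = {}

def register_backend(name):
    # Decorator that files an attention kernel under a backend name.
    def deco(fn):
        _BACKENDS[name] = fn
        return fn
    return deco

@register_backend("native")
def _native(q, k, v, attn_mask=None):
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

def dispatch(q, k, v, attn_mask=None, backend=None):
    # backend=None falls back to the default kernel.
    return _BACKENDS[backend or "native"](q, k, v, attn_mask=attn_mask)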
@@ -104,7 +105,7 @@ class PhotonAttnProcessor2_0:
         """
         Apply Photon attention using PhotonAttention module.

-        Parameters:
+        Args:
             attn: PhotonAttention module containing projection layers
             hidden_states: Image tokens [B, L_img, D]
             encoder_hidden_states: Text tokens [B, L_txt, D]
@@ -113,9 +114,7 @@ class PhotonAttnProcessor2_0:
         """

         if encoder_hidden_states is None:
-            raise ValueError(
-                "PhotonAttnProcessor2_0 requires 'encoder_hidden_states' containing text tokens."
-            )
+            raise ValueError("PhotonAttnProcessor2_0 requires 'encoder_hidden_states' containing text tokens.")

         # Project image tokens to Q, K, V
         img_qkv = attn.img_qkv_proj(hidden_states)
@@ -164,14 +163,24 @@ class PhotonAttnProcessor2_0:
         joint_mask = torch.cat([attention_mask, ones_img], dim=-1)
         attn_mask_tensor = joint_mask[:, None, None, :].expand(-1, attn.heads, l_img, -1)

-        # Apply scaled dot-product attention
-        attn_output = torch.nn.functional.scaled_dot_product_attention(
-            img_q.contiguous(), k.contiguous(), v.contiguous(), attn_mask=attn_mask_tensor
+        # Apply attention using dispatch_attention_fn for backend support
+        # Reshape to match dispatch_attention_fn expectations: [B, L, H, D]
+        query = img_q.transpose(1, 2)  # [B, L_img, H, D]
+        key = k.transpose(1, 2)  # [B, L_txt + L_img, H, D]
+        value = v.transpose(1, 2)  # [B, L_txt + L_img, H, D]
+
+        attn_output = dispatch_attention_fn(
+            query,
+            key,
+            value,
+            attn_mask=attn_mask_tensor,
+            backend=self._attention_backend,
+            parallel_config=self._parallel_config,
         )

-        # Reshape from [B, H, L_img, D] to [B, L_img, H*D]
-        batch_size, num_heads, seq_len, head_dim = attn_output.shape
-        attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, num_heads * head_dim)
+        # Reshape from [B, L_img, H, D] to [B, L_img, H*D]
+        batch_size, seq_len, num_heads, head_dim = attn_output.shape
+        attn_output = attn_output.reshape(batch_size, seq_len, num_heads * head_dim)

         # Apply output projection
         attn_output = attn.to_out[0](attn_output)
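The layout change is the subtle part of this hunk: torch's scaled_dot_product_attention consumes [B, H, L, D], while dispatch_attention_fn takes and returns [B, L, H, D]. That is why the transposes appear before the call and the head-flattening reshape afterwards no longer needs one. A self-contained sketch of the convention, using SDPA as a stand-in backend:

import torch
import torch.nn.functional as F

def sdpa_blhd(query, key, value, attn_mask=None):
    # Accepts the [B, L, H, D] layout used above and bridges to PyTorch SDPA,
    # which expects [B, H, L, D]; the result is transposed back.
    out = F.scaled_dot_product_attention(
        query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), attn_mask=attn_mask
    )
    return out.transpose(1, 2)

b, l_img, l_txt, heads, dim = 2, 16, 7, 4, 8
q = torch.randn(b, l_img, heads, dim)
k = torch.randn(b, l_txt + l_img, heads, dim)
v = torch.randn(b, l_txt + l_img, heads, dim)
out = sdpa_blhd(q, k, v)
assert out.shape == (b, l_img, heads, dim)
# As in the processor, heads can now be flattened without a transpose:
assert out.reshape(b, l_img, heads * dim).shape == (b, l_img, heads * dim)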
@@ -183,8 +192,8 @@ class PhotonAttnProcessor2_0:

 class PhotonAttention(nn.Module, AttentionModuleMixin):
     r"""
-    Photon-style attention module that handles multi-source tokens and RoPE.
-    Similar to FluxAttention but adapted for Photon's architecture.
+    Photon-style attention module that handles multi-source tokens and RoPE. Similar to FluxAttention but adapted for
+    Photon's architecture.
     """

     _default_processor_cls = PhotonAttnProcessor2_0
@@ -242,14 +251,14 @@ class PhotonAttention(nn.Module, AttentionModuleMixin):


 # inspired from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
-class PhotoEmbedND(nn.Module):
+class PhotonEmbedND(nn.Module):
     r"""
     N-dimensional rotary positional embedding.

     This module creates rotary embeddings (RoPE) across multiple axes, where each axis can have its own embedding
     dimension. The embeddings are combined and returned as a single tensor.

-    Parameters:
+    Args:
         dim (int):
             Base embedding dimension (must be even).
         theta (int):
@@ -258,7 +267,7 @@ class PhotoEmbedND(nn.Module):
             List of embedding dimensions for each axis (each must be even).
     """

-    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
+    def __init__(self, dim: int, theta: int, axes_dim: List[int]):
         super().__init__()
         self.dim = dim
         self.theta = theta
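For intuition about what an N-axis rotary embedder computes: standard 1D RoPE angles are built per axis (height, width), each with its own even dimension, then concatenated along the feature axis. A minimal sketch of that idea; the (cos, sin) packing and the module's exact tensor layout may differ:

import torch

def axis_rope(pos, dim, theta=10000):
    # Standard 1D RoPE angles for one axis; dim must be even.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    angles = pos.unsqueeze(-1).float() * freqs  # (..., dim // 2)
    return torch.stack([angles.cos(), angles.sin()], dim=-1)  # (..., dim // 2, 2)

# A 4x4 patch grid: row and column index per patch, one RoPE per axis.
h_ids = torch.arange(4).repeat_interleave(4)
w_ids = torch.arange(4).repeat(4)
pe = torch.cat([axis_rope(h_ids, 16), axis_rope(w_ids, 16)], dim=-2)  # (16, 16, 2)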
@@ -288,7 +297,7 @@ class MLPEmbedder(nn.Module):
     r"""
     A simple 2-layer MLP used for embedding inputs.

-    Parameters:
+    Args:
         in_dim (`int`):
             Dimensionality of the input features.
         hidden_dim (`int`):
@@ -316,7 +325,7 @@ class Modulation(nn.Module):
     Given an input vector, the module projects it through a linear layer to produce six chunks, which are grouped into
     two tuples `(shift, scale, gate)`.

-    Parameters:
+    Args:
         dim (`int`):
             Dimensionality of the input vector. The output will have `6 * dim` features internally.
@@ -340,7 +349,7 @@ class PhotonBlock(nn.Module):
     r"""
     Multimodal transformer block with text–image cross-attention, modulation, and MLP.

-    Parameters:
+    Args:
         hidden_size (`int`):
             Dimension of the hidden representations.
         num_heads (`int`):
@@ -421,7 +430,7 @@ class PhotonBlock(nn.Module):
         r"""
         Runs modulation-gated cross-attention and MLP, with residual connections.

-        Parameters:
+        Args:
             hidden_states (`torch.Tensor`):
                 Image tokens of shape `(B, L_img, hidden_size)`.
             encoder_hidden_states (`torch.Tensor`):
@@ -468,7 +477,7 @@ class FinalLayer(nn.Module):
     This layer applies a normalized and modulated transformation to input tokens and projects them into patch-level
     outputs.

-    Parameters:
+    Args:
         hidden_size (`int`):
             Dimensionality of the input tokens.
         patch_size (`int`):
@@ -505,7 +514,7 @@ def img2seq(img: Tensor, patch_size: int) -> Tensor:
     r"""
     Flattens an image tensor into a sequence of non-overlapping patches.

-    Parameters:
+    Args:
         img (`torch.Tensor`):
             Input image tensor of shape `(B, C, H, W)`.
         patch_size (`int`):
@@ -523,7 +532,7 @@ def seq2img(seq: Tensor, patch_size: int, shape: Tensor) -> Tensor:
     r"""
     Reconstructs an image tensor from a sequence of patches (inverse of `img2seq`).

-    Parameters:
+    Args:
         seq (`torch.Tensor`):
             Patch sequence of shape `(B, L, C * patch_size * patch_size)`, where
             `L = (H // patch_size) * (W // patch_size)`.
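The transformer file imports fold/unfold from torch.nn.functional (see the import hunk above), and the two docstrings describe an exact inverse pair. A runnable sketch of that round trip, consistent with the documented shapes but not the file's literal code:

import torch
from torch.nn.functional import fold, unfold

def patchify(img, p):
    # (B, C, H, W) -> (B, L, C * p * p) with L = (H // p) * (W // p)
    return unfold(img, kernel_size=p, stride=p).transpose(1, 2)

def unpatchify(seq, p, shape):
    # Exact inverse when patches do not overlap (stride == kernel size).
    _, _, h, w = shape
    return fold(seq.transpose(1, 2), output_size=(h, w), kernel_size=p, stride=p)

x = torch.randn(1, 16, 8, 8)
assert torch.equal(unpatchify(patchify(x, 2), 2, x.shape), x)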
@@ -550,7 +559,7 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
     r"""
     Transformer-based 2D model for text-to-image generation.

-    Parameters:
+    Args:
         in_channels (`int`, *optional*, defaults to 16):
             Number of input channels in the latent image.
         patch_size (`int`, *optional*, defaults to 2):
@@ -650,7 +659,7 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):

         self.hidden_size = hidden_size
         self.num_heads = num_heads
-        self.pe_embedder = PhotoEmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
+        self.pe_embedder = PhotonEmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
         self.img_in = nn.Linear(self.in_channels * self.patch_size**2, self.hidden_size, bias=True)
         self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
         self.txt_in = nn.Linear(context_in_dim, self.hidden_size)
@@ -683,11 +692,10 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):

     def forward(
         self,
-        image_latent: Tensor,
+        hidden_states: Tensor,
         timestep: Tensor,
-        cross_attn_conditioning: Tensor,
-        micro_conditioning: Tensor,
-        cross_attn_mask: None | Tensor = None,
+        encoder_hidden_states: Tensor,
+        attention_mask: Optional[Tensor] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
     ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
@@ -697,16 +705,14 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
         The latent image is split into patch tokens, combined with text conditioning, and processed through a stack of
         transformer blocks modulated by the timestep. The output is reconstructed into the latent image space.

-        Parameters:
-            image_latent (`torch.Tensor`):
+        Args:
+            hidden_states (`torch.Tensor`):
                 Input latent image tensor of shape `(B, C, H, W)`.
             timestep (`torch.Tensor`):
                 Timestep tensor of shape `(B,)` or `(1,)`, used for temporal conditioning.
-            cross_attn_conditioning (`torch.Tensor`):
+            encoder_hidden_states (`torch.Tensor`):
                 Text conditioning tensor of shape `(B, L_txt, context_in_dim)`.
-            micro_conditioning (`torch.Tensor`):
-                Extra conditioning vector (currently unused, reserved for future use).
-            cross_attn_mask (`torch.Tensor`, *optional*):
+            attention_mask (`torch.Tensor`, *optional*):
                 Boolean mask of shape `(B, L_txt)`, where `0` marks padding in the text sequence.
             attention_kwargs (`dict`, *optional*):
                 Additional arguments passed to attention layers.
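For downstream callers, the signature change is a mechanical keyword swap, with micro_conditioning dropped outright. A hypothetical migration helper, shown only to summarize the mapping:

RENAMES = {
    "image_latent": "hidden_states",
    "cross_attn_conditioning": "encoder_hidden_states",
    "cross_attn_mask": "attention_mask",
}

def migrate_forward_kwargs(old_kwargs: dict) -> dict:
    # micro_conditioning was unused and has no equivalent in the new signature.
    return {RENAMES.get(k, k): v for k, v in old_kwargs.items() if k != "micro_conditioning"}

print(migrate_forward_kwargs({"image_latent": "x", "micro_conditioning": None}))
# {'hidden_states': 'x'}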
@@ -719,15 +725,15 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
             - `sample` (`torch.Tensor`): Output latent image of shape `(B, C, H, W)`.
         """
         # Process text conditioning
-        txt = self.txt_in(cross_attn_conditioning)
+        txt = self.txt_in(encoder_hidden_states)

         # Convert image to sequence and embed
-        img = img2seq(image_latent, self.patch_size)
+        img = img2seq(hidden_states, self.patch_size)
         img = self.img_in(img)

         # Generate positional embeddings
-        bs, _, h, w = image_latent.shape
-        img_ids = get_image_ids(bs, h, w, patch_size=self.patch_size, device=image_latent.device)
+        bs, _, h, w = hidden_states.shape
+        img_ids = get_image_ids(bs, h, w, patch_size=self.patch_size, device=hidden_states.device)
         pe = self.pe_embedder(img_ids)

         # Compute time embedding
@@ -742,7 +748,7 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
                     txt,
                     vec,
                     pe,
-                    cross_attn_mask,
+                    attention_mask,
                 )
             else:
                 img = block(
@@ -750,12 +756,12 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
                     encoder_hidden_states=txt,
                     temb=vec,
                     image_rotary_emb=pe,
-                    attention_mask=cross_attn_mask,
+                    attention_mask=attention_mask,
                 )

         # Final layer and convert back to image
         img = self.final_layer(img, vec)
-        output = seq2img(img, self.patch_size, image_latent.shape)
+        output = seq2img(img, self.patch_size, hidden_states.shape)

         if not return_dict:
             return (output,)
@@ -718,6 +718,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         StableDiffusionXLPAGPipeline,
     )
     from .paint_by_example import PaintByExamplePipeline
+    from .photon import PhotonPipeline
     from .pia import PIAPipeline
     from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
     from .qwenimage import (
@@ -206,11 +206,11 @@ EXAMPLE_DOC_STRING = """
         >>> from diffusers import PhotonPipeline

         >>> # Load pipeline with from_pretrained
-        >>> pipe = PhotonPipeline.from_pretrained("path/to/photon_checkpoint")
+        >>> pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft")
         >>> pipe.to("cuda")

         >>> prompt = "A digital painting of a rusty, vintage tram on a sandy beach"
-        >>> image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0]
+        >>> image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]
         >>> image.save("photon_output.png")
         ```
 """
@@ -717,11 +717,10 @@ class PhotonPipeline(

                 # Forward through transformer
                 noise_pred = self.transformer(
-                    image_latent=latents_in,
+                    hidden_states=latents_in,
                     timestep=t_cont,
-                    cross_attn_conditioning=ca_embed,
-                    micro_conditioning=None,
-                    cross_attn_mask=ca_mask,
+                    encoder_hidden_states=ca_embed,
+                    attention_mask=ca_mask,
                     return_dict=False,
                 )[0]
@@ -68,28 +68,28 @@ class PhotonPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         tokenizer.model_max_length = 64

         torch.manual_seed(0)

-        encoder_params = dict(
-            vocab_size=tokenizer.vocab_size,
-            hidden_size=8,
-            intermediate_size=16,
-            num_hidden_layers=1,
-            num_attention_heads=2,
-            num_key_value_heads=1,
-            head_dim=4,
-            max_position_embeddings=64,
-            layer_types=["full_attention"],
-            attention_bias=False,
-            attention_dropout=0.0,
-            dropout_rate=0.0,
-            hidden_activation="gelu_pytorch_tanh",
-            rms_norm_eps=1e-06,
-            attn_logit_softcapping=50.0,
-            final_logit_softcapping=30.0,
-            query_pre_attn_scalar=4,
-            rope_theta=10000.0,
-            sliding_window=4096,
-        )
+        encoder_params = {
+            "vocab_size": tokenizer.vocab_size,
+            "hidden_size": 8,
+            "intermediate_size": 16,
+            "num_hidden_layers": 1,
+            "num_attention_heads": 2,
+            "num_key_value_heads": 1,
+            "head_dim": 4,
+            "max_position_embeddings": 64,
+            "layer_types": ["full_attention"],
+            "attention_bias": False,
+            "attention_dropout": 0.0,
+            "dropout_rate": 0.0,
+            "hidden_activation": "gelu_pytorch_tanh",
+            "rms_norm_eps": 1e-06,
+            "attn_logit_softcapping": 50.0,
+            "final_logit_softcapping": 30.0,
+            "query_pre_attn_scalar": 4,
+            "rope_theta": 10000.0,
+            "sliding_window": 4096,
+        }
         encoder_config = T5GemmaModuleConfig(**encoder_params)
         text_encoder_config = T5GemmaConfig(encoder=encoder_config, is_encoder_decoder=False, **encoder_params)
         text_encoder = T5GemmaEncoder(text_encoder_config)
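The dict(...) to literal change in the test setup is purely stylistic; both spellings build the same mapping:

assert dict(hidden_size=8, rope_theta=10000.0) == {"hidden_size": 8, "rope_theta": 10000.0}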