From 0ef0dc6837a4825ec4457b39d09f2c31057e10d2 Mon Sep 17 00:00:00 2001 From: DavidBert Date: Fri, 17 Oct 2025 15:21:24 +0000 Subject: [PATCH] use dispatch_attention_fn for multiple attention backend support --- scripts/convert_photon_to_diffusers.py | 7 +- src/diffusers/__init__.py | 2 + src/diffusers/models/__init__.py | 1 + .../models/transformers/transformer_photon.py | 100 ++++++++++-------- src/diffusers/pipelines/__init__.py | 1 + .../pipelines/photon/pipeline_photon.py | 11 +- .../pipelines/photon/test_pipeline_photon.py | 44 ++++---- 7 files changed, 85 insertions(+), 81 deletions(-) diff --git a/scripts/convert_photon_to_diffusers.py b/scripts/convert_photon_to_diffusers.py index 6e4a49de37..c66bc31418 100644 --- a/scripts/convert_photon_to_diffusers.py +++ b/scripts/convert_photon_to_diffusers.py @@ -13,9 +13,6 @@ from typing import Dict, Tuple import torch from safetensors.torch import save_file - -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) - from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel from diffusers.pipelines.photon import PhotonPipeline @@ -74,14 +71,12 @@ def create_parameter_mapping(depth: int) -> dict: mapping[f"blocks.{i}.txt_kv_proj.weight"] = f"blocks.{i}.attention.txt_kv_proj.weight" # QK norm moved to attention module and renamed to match Attention's qk_norm structure - # Old: qk_norm.query_norm / qk_norm.key_norm -> New: norm_q / norm_k mapping[f"blocks.{i}.qk_norm.query_norm.scale"] = f"blocks.{i}.attention.norm_q.weight" mapping[f"blocks.{i}.qk_norm.key_norm.scale"] = f"blocks.{i}.attention.norm_k.weight" mapping[f"blocks.{i}.qk_norm.query_norm.weight"] = f"blocks.{i}.attention.norm_q.weight" mapping[f"blocks.{i}.qk_norm.key_norm.weight"] = f"blocks.{i}.attention.norm_k.weight" # K norm for text tokens moved to attention module - # Old: k_norm -> New: norm_added_k mapping[f"blocks.{i}.k_norm.scale"] = f"blocks.{i}.attention.norm_added_k.weight" mapping[f"blocks.{i}.k_norm.weight"] = f"blocks.{i}.attention.norm_added_k.weight" @@ -306,7 +301,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser(description="Convert Photon checkpoint to diffusers format") parser.add_argument( - "--checkpoint_path", type=str, required=True, help="Path to the original Photon checkpoint (.pth file)" + "--checkpoint_path", type=str, required=True, help="Path to the original Photon checkpoint (.pth file )" ) parser.add_argument( diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index c2528bc50f..28b2ae2549 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -516,6 +516,7 @@ else: "MusicLDMPipeline", "OmniGenPipeline", "PaintByExamplePipeline", + "PhotonPipeline", "PIAPipeline", "PixArtAlphaPipeline", "PixArtSigmaPAGPipeline", @@ -1180,6 +1181,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: MusicLDMPipeline, OmniGenPipeline, PaintByExamplePipeline, + PhotonPipeline, PIAPipeline, PixArtAlphaPipeline, PixArtSigmaPAGPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index f3164e48cf..2151e602b2 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -191,6 +191,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: LuminaNextDiT2DModel, MochiTransformer3DModel, OmniGenTransformer2DModel, + PhotonTransformer2DModel, PixArtTransformer2DModel, PriorTransformer, QwenImageTransformer2DModel, diff --git a/src/diffusers/models/transformers/transformer_photon.py 
b/src/diffusers/models/transformers/transformer_photon.py index 6c94e9f67a..1a40a82971 100644 --- a/src/diffusers/models/transformers/transformer_photon.py +++ b/src/diffusers/models/transformers/transformer_photon.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch from torch import Tensor, nn @@ -21,7 +21,7 @@ from torch.nn.functional import fold, unfold from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging from ..attention import AttentionMixin, AttentionModuleMixin -from ..attention_processor import Attention +from ..attention_dispatch import dispatch_attention_fn from ..embeddings import get_timestep_embedding from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin @@ -35,7 +35,7 @@ def get_image_ids(batch_size: int, height: int, width: int, patch_size: int, dev r""" Generates 2D patch coordinate indices for a batch of images. - Parameters: + Args: batch_size (`int`): Number of images in the batch. height (`int`): @@ -63,7 +63,7 @@ def apply_rope(xq: Tensor, freqs_cis: Tensor) -> Tensor: r""" Applies rotary positional embeddings (RoPE) to a query tensor. - Parameters: + Args: xq (`torch.Tensor`): Input tensor of shape `(..., dim)` representing the queries. freqs_cis (`torch.Tensor`): @@ -82,11 +82,12 @@ def apply_rope(xq: Tensor, freqs_cis: Tensor) -> Tensor: class PhotonAttnProcessor2_0: r""" - Processor for implementing Photon-style attention with multi-source tokens and RoPE. Properly integrates with - diffusers Attention module while handling Photon-specific logic. + Processor for implementing Photon-style attention with multi-source tokens and RoPE. Supports multiple attention + backends (Flash Attention, Sage Attention, etc.) via dispatch_attention_fn. """ _attention_backend = None + _parallel_config = None def __init__(self): if not hasattr(torch.nn.functional, "scaled_dot_product_attention"): @@ -104,7 +105,7 @@ class PhotonAttnProcessor2_0: """ Apply Photon attention using PhotonAttention module. - Parameters: + Args: attn: PhotonAttention module containing projection layers hidden_states: Image tokens [B, L_img, D] encoder_hidden_states: Text tokens [B, L_txt, D] @@ -113,9 +114,7 @@ class PhotonAttnProcessor2_0: """ if encoder_hidden_states is None: - raise ValueError( - "PhotonAttnProcessor2_0 requires 'encoder_hidden_states' containing text tokens." 
- ) + raise ValueError("PhotonAttnProcessor2_0 requires 'encoder_hidden_states' containing text tokens.") # Project image tokens to Q, K, V img_qkv = attn.img_qkv_proj(hidden_states) @@ -164,14 +163,24 @@ class PhotonAttnProcessor2_0: joint_mask = torch.cat([attention_mask, ones_img], dim=-1) attn_mask_tensor = joint_mask[:, None, None, :].expand(-1, attn.heads, l_img, -1) - # Apply scaled dot-product attention - attn_output = torch.nn.functional.scaled_dot_product_attention( - img_q.contiguous(), k.contiguous(), v.contiguous(), attn_mask=attn_mask_tensor + # Apply attention using dispatch_attention_fn for backend support + # Reshape to match dispatch_attention_fn expectations: [B, L, H, D] + query = img_q.transpose(1, 2) # [B, L_img, H, D] + key = k.transpose(1, 2) # [B, L_txt + L_img, H, D] + value = v.transpose(1, 2) # [B, L_txt + L_img, H, D] + + attn_output = dispatch_attention_fn( + query, + key, + value, + attn_mask=attn_mask_tensor, + backend=self._attention_backend, + parallel_config=self._parallel_config, ) - # Reshape from [B, H, L_img, D] to [B, L_img, H*D] - batch_size, num_heads, seq_len, head_dim = attn_output.shape - attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, num_heads * head_dim) + # Reshape from [B, L_img, H, D] to [B, L_img, H*D] + batch_size, seq_len, num_heads, head_dim = attn_output.shape + attn_output = attn_output.reshape(batch_size, seq_len, num_heads * head_dim) # Apply output projection attn_output = attn.to_out[0](attn_output) @@ -183,8 +192,8 @@ class PhotonAttnProcessor2_0: class PhotonAttention(nn.Module, AttentionModuleMixin): r""" - Photon-style attention module that handles multi-source tokens and RoPE. - Similar to FluxAttention but adapted for Photon's architecture. + Photon-style attention module that handles multi-source tokens and RoPE. Similar to FluxAttention but adapted for + Photon's architecture. """ _default_processor_cls = PhotonAttnProcessor2_0 @@ -242,14 +251,14 @@ class PhotonAttention(nn.Module, AttentionModuleMixin): # inspired from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py -class PhotoEmbedND(nn.Module): +class PhotonEmbedND(nn.Module): r""" N-dimensional rotary positional embedding. This module creates rotary embeddings (RoPE) across multiple axes, where each axis can have its own embedding dimension. The embeddings are combined and returned as a single tensor - Parameters: + Args: dim (int): Base embedding dimension (must be even). theta (int): @@ -258,7 +267,7 @@ class PhotoEmbedND(nn.Module): List of embedding dimensions for each axis (each must be even). """ - def __init__(self, dim: int, theta: int, axes_dim: list[int]): + def __init__(self, dim: int, theta: int, axes_dim: List[int]): super().__init__() self.dim = dim self.theta = theta @@ -288,7 +297,7 @@ class MLPEmbedder(nn.Module): r""" A simple 2-layer MLP used for embedding inputs. - Parameters: + Args: in_dim (`int`): Dimensionality of the input features. hidden_dim (`int`): @@ -316,7 +325,7 @@ class Modulation(nn.Module): Given an input vector, the module projects it through a linear layer to produce six chunks, which are grouped into two tuples `(shift, scale, gate)`. - Parameters: + Args: dim (`int`): Dimensionality of the input vector. The output will have `6 * dim` features internally. @@ -340,7 +349,7 @@ class PhotonBlock(nn.Module): r""" Multimodal transformer block with text–image cross-attention, modulation, and MLP. - Parameters: + Args: hidden_size (`int`): Dimension of the hidden representations. 
num_heads (`int`): @@ -421,7 +430,7 @@ class PhotonBlock(nn.Module): r""" Runs modulation-gated cross-attention and MLP, with residual connections. - Parameters: + Args: hidden_states (`torch.Tensor`): Image tokens of shape `(B, L_img, hidden_size)`. encoder_hidden_states (`torch.Tensor`): @@ -468,7 +477,7 @@ class FinalLayer(nn.Module): This layer applies a normalized and modulated transformation to input tokens and projects them into patch-level outputs. - Parameters: + Args: hidden_size (`int`): Dimensionality of the input tokens. patch_size (`int`): @@ -505,7 +514,7 @@ def img2seq(img: Tensor, patch_size: int) -> Tensor: r""" Flattens an image tensor into a sequence of non-overlapping patches. - Parameters: + Args: img (`torch.Tensor`): Input image tensor of shape `(B, C, H, W)`. patch_size (`int`): @@ -523,7 +532,7 @@ def seq2img(seq: Tensor, patch_size: int, shape: Tensor) -> Tensor: r""" Reconstructs an image tensor from a sequence of patches (inverse of `img2seq`). - Parameters: + Args: seq (`torch.Tensor`): Patch sequence of shape `(B, L, C * patch_size * patch_size)`, where `L = (H // patch_size) * (W // patch_size)`. @@ -550,7 +559,7 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin): r""" Transformer-based 2D model for text to image generation. - Parameters: + Args: in_channels (`int`, *optional*, defaults to 16): Number of input channels in the latent image. patch_size (`int`, *optional*, defaults to 2): @@ -650,7 +659,7 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin): self.hidden_size = hidden_size self.num_heads = num_heads - self.pe_embedder = PhotoEmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim) + self.pe_embedder = PhotonEmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim) self.img_in = nn.Linear(self.in_channels * self.patch_size**2, self.hidden_size, bias=True) self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) self.txt_in = nn.Linear(context_in_dim, self.hidden_size) @@ -683,11 +692,10 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin): def forward( self, - image_latent: Tensor, + hidden_states: Tensor, timestep: Tensor, - cross_attn_conditioning: Tensor, - micro_conditioning: Tensor, - cross_attn_mask: None | Tensor = None, + encoder_hidden_states: Tensor, + attention_mask: Optional[Tensor] = None, attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]: @@ -697,16 +705,14 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin): The latent image is split into patch tokens, combined with text conditioning, and processed through a stack of transformer blocks modulated by the timestep. The output is reconstructed into the latent image space. - Parameters: - image_latent (`torch.Tensor`): + Args: + hidden_states (`torch.Tensor`): Input latent image tensor of shape `(B, C, H, W)`. timestep (`torch.Tensor`): Timestep tensor of shape `(B,)` or `(1,)`, used for temporal conditioning. - cross_attn_conditioning (`torch.Tensor`): + encoder_hidden_states (`torch.Tensor`): Text conditioning tensor of shape `(B, L_txt, context_in_dim)`. - micro_conditioning (`torch.Tensor`): - Extra conditioning vector (currently unused, reserved for future use). - cross_attn_mask (`torch.Tensor`, *optional*): + attention_mask (`torch.Tensor`, *optional*): Boolean mask of shape `(B, L_txt)`, where `0` marks padding in the text sequence. 
attention_kwargs (`dict`, *optional*): Additional arguments passed to attention layers. @@ -719,15 +725,15 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin): - `sample` (`torch.Tensor`): Output latent image of shape `(B, C, H, W)`. """ # Process text conditioning - txt = self.txt_in(cross_attn_conditioning) + txt = self.txt_in(encoder_hidden_states) # Convert image to sequence and embed - img = img2seq(image_latent, self.patch_size) + img = img2seq(hidden_states, self.patch_size) img = self.img_in(img) # Generate positional embeddings - bs, _, h, w = image_latent.shape - img_ids = get_image_ids(bs, h, w, patch_size=self.patch_size, device=image_latent.device) + bs, _, h, w = hidden_states.shape + img_ids = get_image_ids(bs, h, w, patch_size=self.patch_size, device=hidden_states.device) pe = self.pe_embedder(img_ids) # Compute time embedding @@ -742,7 +748,7 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin): txt, vec, pe, - cross_attn_mask, + attention_mask, ) else: img = block( @@ -750,12 +756,12 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin): encoder_hidden_states=txt, temb=vec, image_rotary_emb=pe, - attention_mask=cross_attn_mask, + attention_mask=attention_mask, ) # Final layer and convert back to image img = self.final_layer(img, vec) - output = seq2img(img, self.patch_size, image_latent.shape) + output = seq2img(img, self.patch_size, hidden_states.shape) if not return_dict: return (output,) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 1fa8dcf0c8..a44c92a834 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -718,6 +718,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: StableDiffusionXLPAGPipeline, ) from .paint_by_example import PaintByExamplePipeline + from .photon import PhotonPipeline from .pia import PIAPipeline from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline from .qwenimage import ( diff --git a/src/diffusers/pipelines/photon/pipeline_photon.py b/src/diffusers/pipelines/photon/pipeline_photon.py index 457bbd2223..b394b12d83 100644 --- a/src/diffusers/pipelines/photon/pipeline_photon.py +++ b/src/diffusers/pipelines/photon/pipeline_photon.py @@ -206,11 +206,11 @@ EXAMPLE_DOC_STRING = """ >>> from diffusers import PhotonPipeline >>> # Load pipeline with from_pretrained - >>> pipe = PhotonPipeline.from_pretrained("path/to/photon_checkpoint") + >>> pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft") >>> pipe.to("cuda") >>> prompt = "A digital painting of a rusty, vintage tram on a sandy beach" - >>> image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0] + >>> image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0] >>> image.save("photon_output.png") ``` """ @@ -717,11 +717,10 @@ class PhotonPipeline( # Forward through transformer noise_pred = self.transformer( - image_latent=latents_in, + hidden_states=latents_in, timestep=t_cont, - cross_attn_conditioning=ca_embed, - micro_conditioning=None, - cross_attn_mask=ca_mask, + encoder_hidden_states=ca_embed, + attention_mask=ca_mask, return_dict=False, )[0] diff --git a/tests/pipelines/photon/test_pipeline_photon.py b/tests/pipelines/photon/test_pipeline_photon.py index 9ac361c75b..9c5803b5d0 100644 --- a/tests/pipelines/photon/test_pipeline_photon.py +++ b/tests/pipelines/photon/test_pipeline_photon.py @@ -68,28 +68,28 @@ class PhotonPipelineFastTests(PipelineTesterMixin, unittest.TestCase): 
tokenizer.model_max_length = 64 torch.manual_seed(0) - - encoder_params = dict( - vocab_size=tokenizer.vocab_size, - hidden_size=8, - intermediate_size=16, - num_hidden_layers=1, - num_attention_heads=2, - num_key_value_heads=1, - head_dim=4, - max_position_embeddings=64, - layer_types=["full_attention"], - attention_bias=False, - attention_dropout=0.0, - dropout_rate=0.0, - hidden_activation="gelu_pytorch_tanh", - rms_norm_eps=1e-06, - attn_logit_softcapping=50.0, - final_logit_softcapping=30.0, - query_pre_attn_scalar=4, - rope_theta=10000.0, - sliding_window=4096, - ) + + encoder_params = { + "vocab_size": tokenizer.vocab_size, + "hidden_size": 8, + "intermediate_size": 16, + "num_hidden_layers": 1, + "num_attention_heads": 2, + "num_key_value_heads": 1, + "head_dim": 4, + "max_position_embeddings": 64, + "layer_types": ["full_attention"], + "attention_bias": False, + "attention_dropout": 0.0, + "dropout_rate": 0.0, + "hidden_activation": "gelu_pytorch_tanh", + "rms_norm_eps": 1e-06, + "attn_logit_softcapping": 50.0, + "final_logit_softcapping": 30.0, + "query_pre_attn_scalar": 4, + "rope_theta": 10000.0, + "sliding_window": 4096, + } encoder_config = T5GemmaModuleConfig(**encoder_params) text_encoder_config = T5GemmaConfig(encoder=encoder_config, is_encoder_decoder=False, **encoder_params) text_encoder = T5GemmaEncoder(text_encoder_config)
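
Usage sketch for the new backend dispatch (this assumes the `set_attention_backend` helper on recent
diffusers models and the "flash" backend name; both are assumptions and may vary by version, and the
default native SDPA path needs no extra call):

    >>> import torch
    >>> from diffusers import PhotonPipeline

    >>> pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft", torch_dtype=torch.bfloat16)
    >>> pipe.to("cuda")

    >>> # PhotonAttnProcessor2_0 now routes through dispatch_attention_fn, so a non-default
    >>> # attention backend can be selected on the transformer (assumed API, see note above).
    >>> pipe.transformer.set_attention_backend("flash")

    >>> prompt = "A digital painting of a rusty, vintage tram on a sandy beach"
    >>> image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]
    >>> image.save("photon_output.png")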