diff --git a/src/diffusers/models/transformers/transformer_photon.py b/src/diffusers/models/transformers/transformer_photon.py
index c5809bc2c0..8bd6fdd5e5 100644
--- a/src/diffusers/models/transformers/transformer_photon.py
+++ b/src/diffusers/models/transformers/transformer_photon.py
@@ -1,4 +1,4 @@
-# Copyright 2025 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The Photoroom and The HuggingFace Teams. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -31,7 +31,7 @@ from ..normalization import RMSNorm
 logger = logging.get_logger(__name__)
 
 
-def get_image_ids(batch_size: int, height: int, width: int, patch_size: int, device: torch.device) -> Tensor:
+def get_image_ids(batch_size: int, height: int, width: int, patch_size: int, device: torch.device) -> torch.Tensor:
     r"""
     Generates 2D patch coordinate indices for a batch of images.
 
@@ -59,7 +59,7 @@ def get_image_ids(batch_size: int, height: int, width: int, patch_size: int, dev
     return img_ids.reshape((height // patch_size) * (width // patch_size), 2).unsqueeze(0).repeat(batch_size, 1, 1)
 
 
-def apply_rope(xq: Tensor, freqs_cis: Tensor) -> Tensor:
+def apply_rope(xq: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
     r"""
     Applies rotary positional embeddings (RoPE) to a query tensor.
 
@@ -273,7 +273,7 @@ class PhotonEmbedND(nn.Module):
         self.theta = theta
         self.axes_dim = axes_dim
 
-    def rope(self, pos: Tensor, dim: int, theta: int) -> Tensor:
+    def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
         assert dim % 2 == 0
         scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
         omega = 1.0 / (theta**scale)
@@ -284,7 +284,7 @@ class PhotonEmbedND(nn.Module):
         out = out.reshape(*out.shape[:-1], 2, 2)
         return out.float()
 
-    def forward(self, ids: Tensor) -> Tensor:
+    def forward(self, ids: torch.Tensor) -> torch.Tensor:
         n_axes = ids.shape[-1]
         emb = torch.cat(
             [self.rope(ids[:, :, i], self.axes_dim[i], self.theta) for i in range(n_axes)],
@@ -314,7 +314,7 @@ class MLPEmbedder(nn.Module):
         self.silu = nn.SiLU()
         self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
 
-    def forward(self, x: Tensor) -> Tensor:
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.out_layer(self.silu(self.in_layer(x)))
 
 
@@ -340,7 +340,7 @@ class Modulation(nn.Module):
         nn.init.constant_(self.lin.weight, 0)
         nn.init.constant_(self.lin.bias, 0)
 
-    def forward(self, vec: Tensor) -> Tuple[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor, Tensor]]:
+    def forward(self, vec: torch.Tensor) -> Tuple[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor, Tensor]]:
         out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(6, dim=-1)
         return tuple(out[:3]), tuple(out[3:])
 
@@ -383,7 +383,7 @@ class PhotonBlock(nn.Module):
         hidden_size: int,
         num_heads: int,
         mlp_ratio: float = 4.0,
-        qk_scale: float | None = None,
+        qk_scale: Optional[float] = None,
     ):
         super().__init__()
 
@@ -420,13 +420,13 @@ class PhotonBlock(nn.Module):
 
     def forward(
         self,
-        hidden_states: Tensor,
-        encoder_hidden_states: Tensor,
-        temb: Tensor,
-        image_rotary_emb: Tensor,
-        attention_mask: Tensor | None = None,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        image_rotary_emb: torch.Tensor,
+        attention_mask: Optional[Tensor] = None,
         **kwargs: dict[str, Any],
-    ) -> Tensor:
+    ) -> torch.Tensor:
         r"""
         Runs modulation-gated cross-attention and MLP, with residual connections.
 
@@ -503,14 +503,14 @@ class FinalLayer(nn.Module):
         self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
         self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
 
-    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
+    def forward(self, x: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
         shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
         x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
         x = self.linear(x)
         return x
 
 
-def img2seq(img: Tensor, patch_size: int) -> Tensor:
+def img2seq(img: torch.Tensor, patch_size: int) -> torch.Tensor:
     r"""
     Flattens an image tensor into a sequence of non-overlapping patches.
 
@@ -528,7 +528,7 @@ def img2seq(img: Tensor, patch_size: int) -> Tensor:
     return unfold(img, kernel_size=patch_size, stride=patch_size).transpose(1, 2)
 
 
-def seq2img(seq: Tensor, patch_size: int, shape: Tensor) -> Tensor:
+def seq2img(seq: torch.Tensor, patch_size: int, shape: torch.Tensor) -> torch.Tensor:
     r"""
     Reconstructs an image tensor from a sequence of patches (inverse of `img2seq`).
 
@@ -679,7 +679,7 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
 
         self.gradient_checkpointing = False
 
-    def _compute_timestep_embedding(self, timestep: Tensor, dtype: torch.dtype) -> Tensor:
+    def _compute_timestep_embedding(self, timestep: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
         return self.time_in(
             get_timestep_embedding(
                 timesteps=timestep,
@@ -692,9 +692,9 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
 
     def forward(
         self,
-        hidden_states: Tensor,
-        timestep: Tensor,
-        encoder_hidden_states: Tensor,
+        hidden_states: torch.Tensor,
+        timestep: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
         attention_mask: Optional[Tensor] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
diff --git a/src/diffusers/pipelines/photon/pipeline_photon.py b/src/diffusers/pipelines/photon/pipeline_photon.py
index b394b12d83..4a10899ede 100644
--- a/src/diffusers/pipelines/photon/pipeline_photon.py
+++ b/src/diffusers/pipelines/photon/pipeline_photon.py
@@ -1,4 +1,4 @@
-# Copyright 2025 The Photoroom and the HuggingFace Teams. All rights reserved.
+# Copyright 2025 The Photoroom and The HuggingFace Teams. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/tests/pipelines/photon/test_pipeline_photon.py b/tests/pipelines/photon/test_pipeline_photon.py
index 0267ebeda2..c29c6ce0b0 100644
--- a/tests/pipelines/photon/test_pipeline_photon.py
+++ b/tests/pipelines/photon/test_pipeline_photon.py
@@ -17,7 +17,11 @@ from ..pipeline_params import TEXT_TO_IMAGE_PARAMS
 from ..test_pipelines_common import PipelineTesterMixin
 
 
-@pytest.mark.xfail(condition=is_transformers_version(">", "4.57.1"), reason="See https://github.com/huggingface/diffusers/pull/12456#issuecomment-3424228544", strict=False)
+@pytest.mark.xfail(
+    condition=is_transformers_version(">", "4.57.1"),
+    reason="See https://github.com/huggingface/diffusers/pull/12456#issuecomment-3424228544",
+    strict=False,
+)
 class PhotonPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     pipeline_class = PhotonPipeline
     params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
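
Reviewer note (not part of the patch): every annotation change above rewrites `Tensor` as `torch.Tensor`, and `from torch import Tensor` exposes the very same class as `torch.Tensor`, so the patch is a pure readability/consistency fix with no runtime effect. The standalone sketch below double-checks that, and also exercises the img2seq/seq2img round trip around the unfold call visible in the hunk at line 528. The seq2img body here is an assumed fold-based inverse written for illustration, not the file's actual implementation, and `shape` is taken as a `torch.Size` for the self-contained demo even though the patch annotates it as `torch.Tensor`.

# Standalone sanity-check sketch; img2seq matches the unfold call shown in the
# diff, while this seq2img body (fold-based inverse) is an assumption.
import torch
from torch import Tensor
from torch.nn.functional import fold, unfold

assert Tensor is torch.Tensor  # the annotation rename cannot change behavior


def img2seq(img: torch.Tensor, patch_size: int) -> torch.Tensor:
    # (B, C, H, W) -> (B, num_patches, C * patch_size**2)
    return unfold(img, kernel_size=patch_size, stride=patch_size).transpose(1, 2)


def seq2img(seq: torch.Tensor, patch_size: int, shape: torch.Size) -> torch.Tensor:
    # Assumed inverse: fold is exact for non-overlapping patches (stride == kernel).
    return fold(seq.transpose(1, 2), output_size=shape[-2:], kernel_size=patch_size, stride=patch_size)


img = torch.randn(2, 4, 32, 32)
assert torch.equal(seq2img(img2seq(img, 2), 2, img.shape), img)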