
fix nits.

sayakpaul
2025-10-21 04:31:48 -10:00
parent 8de7b9247a
commit 53a2a7aff5
3 changed files with 27 additions and 23 deletions


@@ -1,4 +1,4 @@
-# Copyright 2025 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The Photoroom and The HuggingFace Teams. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -31,7 +31,7 @@ from ..normalization import RMSNorm
 logger = logging.get_logger(__name__)


-def get_image_ids(batch_size: int, height: int, width: int, patch_size: int, device: torch.device) -> Tensor:
+def get_image_ids(batch_size: int, height: int, width: int, patch_size: int, device: torch.device) -> torch.Tensor:
     r"""
     Generates 2D patch coordinate indices for a batch of images.
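As a quick illustration of the contract documented above (a sketch, not part of the diff): for a 4x4 input with patch_size=2 there are four patches, each tagged with its patch coordinate. The shape follows from the return statement in the next hunk; the call assumes `get_image_ids` is importable from this module.

import torch

# Hypothetical usage; output shape is (batch, (H/p)*(W/p), 2).
ids = get_image_ids(batch_size=1, height=4, width=4, patch_size=2, device=torch.device("cpu"))
print(ids.shape)  # torch.Size([1, 4, 2]): one (row, col)-style pair per patch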
@@ -59,7 +59,7 @@ def get_image_ids(batch_size: int, height: int, width: int, patch_size: int, dev
     return img_ids.reshape((height // patch_size) * (width // patch_size), 2).unsqueeze(0).repeat(batch_size, 1, 1)


-def apply_rope(xq: Tensor, freqs_cis: Tensor) -> Tensor:
+def apply_rope(xq: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
     r"""
     Applies rotary positional embeddings (RoPE) to a query tensor.
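For context, a minimal sketch of how a `(..., 2, 2)` rotation-matrix `freqs_cis` (as produced by `PhotonEmbedND.rope` below) is conventionally applied to a query, modeled on the FLUX-style convention; the function's full body is not shown in the diff, so this is an illustration rather than the exact implementation:

import torch

def apply_rope_sketch(xq: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # Pair up the last dimension into 2-vectors and rotate each pair by its
    # 2x2 rotation matrix, then restore the original shape and dtype.
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    return out.reshape(*xq.shape).type_as(xq)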
@@ -273,7 +273,7 @@ class PhotonEmbedND(nn.Module):
         self.theta = theta
         self.axes_dim = axes_dim

-    def rope(self, pos: Tensor, dim: int, theta: int) -> Tensor:
+    def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
         assert dim % 2 == 0
         scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
         omega = 1.0 / (theta**scale)
@@ -284,7 +284,7 @@ class PhotonEmbedND(nn.Module):
         out = out.reshape(*out.shape[:-1], 2, 2)
         return out.float()

-    def forward(self, ids: Tensor) -> Tensor:
+    def forward(self, ids: torch.Tensor) -> torch.Tensor:
         n_axes = ids.shape[-1]
         emb = torch.cat(
             [self.rope(ids[:, :, i], self.axes_dim[i], self.theta) for i in range(n_axes)],
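The two `rope` hunks above elide the middle of the method. In the usual FLUX-style construction, the angles are an outer product of positions and frequencies, stacked into 2x2 rotation matrices; a hedged reconstruction for illustration only (the elided lines may differ):

import torch

def rope_sketch(pos: torch.Tensor, dim: int, theta: int = 10000) -> torch.Tensor:
    # pos: (..., n) positions; returns (..., n, dim/2, 2, 2) rotation matrices.
    assert dim % 2 == 0
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    angles = torch.einsum("...n,d->...nd", pos.to(torch.float64), omega)
    out = torch.stack([angles.cos(), -angles.sin(), angles.sin(), angles.cos()], dim=-1)
    return out.reshape(*out.shape[:-1], 2, 2).float()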
@@ -314,7 +314,7 @@ class MLPEmbedder(nn.Module):
         self.silu = nn.SiLU()
         self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

-    def forward(self, x: Tensor) -> Tensor:
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.out_layer(self.silu(self.in_layer(x)))
@@ -340,7 +340,7 @@ class Modulation(nn.Module):
         nn.init.constant_(self.lin.weight, 0)
         nn.init.constant_(self.lin.bias, 0)

-    def forward(self, vec: Tensor) -> Tuple[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor, Tensor]]:
+    def forward(self, vec: torch.Tensor) -> Tuple[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor, Tensor]]:
         out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(6, dim=-1)
         return tuple(out[:3]), tuple(out[3:])
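Since `forward` splits the projection into two (shift, scale, gate) triples, a typical consumer looks like the following AdaLN-style pattern. A self-contained sketch with stand-in modules; the names are illustrative and not taken from the diff:

import torch
from torch import nn

# Mimic Modulation.forward: six (B, 1, hidden) chunks, one triple per branch.
h = torch.randn(2, 16, 64)                                # (batch, seq, hidden)
six = torch.randn(2, 1, 6 * 64).chunk(6, dim=-1)
(attn_shift, attn_scale, attn_gate), (mlp_shift, mlp_scale, mlp_gate) = six[:3], six[3:]
norm = nn.LayerNorm(64, elementwise_affine=False)
attn, mlp = nn.Identity(), nn.Identity()                  # stand-ins for the real branches
h = h + attn_gate * attn(norm(h) * (1 + attn_scale) + attn_shift)
h = h + mlp_gate * mlp(norm(h) * (1 + mlp_scale) + mlp_shift)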
@@ -383,7 +383,7 @@ class PhotonBlock(nn.Module):
         hidden_size: int,
         num_heads: int,
         mlp_ratio: float = 4.0,
-        qk_scale: float | None = None,
+        qk_scale: Optional[float] = None,
     ):
         super().__init__()
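The `float | None` → `Optional[float]` change is more than cosmetic in one respect: PEP 604 union syntax is only valid at runtime on Python 3.10+ (unless `from __future__ import annotations` defers evaluation), while `Optional[float]` works on older interpreters too. A small sketch of the version-safe spelling; the helper name is hypothetical:

from typing import Optional

def resolve_qk_scale(qk_scale: Optional[float], head_dim: int) -> float:
    # Fall back to the standard 1/sqrt(d) attention scaling when unset.
    return qk_scale if qk_scale is not None else head_dim**-0.5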
@@ -420,13 +420,13 @@
     def forward(
         self,
-        hidden_states: Tensor,
-        encoder_hidden_states: Tensor,
-        temb: Tensor,
-        image_rotary_emb: Tensor,
-        attention_mask: Tensor | None = None,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        image_rotary_emb: torch.Tensor,
+        attention_mask: Optional[Tensor] = None,
         **kwargs: dict[str, Any],
-    ) -> Tensor:
+    ) -> torch.Tensor:
         r"""
         Runs modulation-gated cross-attention and MLP, with residual connections.
@@ -503,14 +503,14 @@ class FinalLayer(nn.Module):
         self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
         self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

-    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
+    def forward(self, x: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
         shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
         x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
         x = self.linear(x)
         return x


-def img2seq(img: Tensor, patch_size: int) -> Tensor:
+def img2seq(img: torch.Tensor, patch_size: int) -> torch.Tensor:
     r"""
     Flattens an image tensor into a sequence of non-overlapping patches.
@@ -528,7 +528,7 @@ def img2seq(img: Tensor, patch_size: int) -> Tensor:
     return unfold(img, kernel_size=patch_size, stride=patch_size).transpose(1, 2)


-def seq2img(seq: Tensor, patch_size: int, shape: Tensor) -> Tensor:
+def seq2img(seq: torch.Tensor, patch_size: int, shape: torch.Tensor) -> torch.Tensor:
     r"""
     Reconstructs an image tensor from a sequence of patches (inverse of `img2seq`).
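A shape-level sanity check of the patchify round trip these two functions describe, using `torch.nn.functional.unfold`/`fold` directly (a sketch; `seq2img`'s exact body is not shown in the diff):

import torch
from torch.nn.functional import fold, unfold

img = torch.randn(2, 16, 64, 64)                      # (B, C, H, W) latent image
seq = unfold(img, kernel_size=2, stride=2).transpose(1, 2)
print(seq.shape)                                      # torch.Size([2, 1024, 64]): B, (H/p)*(W/p), C*p*p
back = fold(seq.transpose(1, 2), output_size=(64, 64), kernel_size=2, stride=2)
assert torch.equal(back, img)                         # non-overlapping patches invert exactly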
@@ -679,7 +679,7 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
         self.gradient_checkpointing = False

-    def _compute_timestep_embedding(self, timestep: Tensor, dtype: torch.dtype) -> Tensor:
+    def _compute_timestep_embedding(self, timestep: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
         return self.time_in(
             get_timestep_embedding(
                 timesteps=timestep,
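`get_timestep_embedding` is the standard sinusoidal embedding helper from `diffusers.models.embeddings`; `time_in` (presumably the `MLPEmbedder` defined earlier) then projects its output to the hidden size. A quick look at the helper's output shape:

import torch
from diffusers.models.embeddings import get_timestep_embedding

t = torch.tensor([0.25, 0.75])                      # one timestep per batch element
emb = get_timestep_embedding(timesteps=t, embedding_dim=256)
print(emb.shape)                                    # torch.Size([2, 256]): sin/cos features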
@@ -692,9 +692,9 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
     def forward(
         self,
-        hidden_states: Tensor,
-        timestep: Tensor,
-        encoder_hidden_states: Tensor,
+        hidden_states: torch.Tensor,
+        timestep: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
         attention_mask: Optional[Tensor] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,


@@ -1,4 +1,4 @@
-# Copyright 2025 The Photoroom and the HuggingFace Teams. All rights reserved.
+# Copyright 2025 The Photoroom and The HuggingFace Teams. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.


@@ -17,7 +17,11 @@ from ..pipeline_params import TEXT_TO_IMAGE_PARAMS
 from ..test_pipelines_common import PipelineTesterMixin


-@pytest.mark.xfail(condition=is_transformers_version(">", "4.57.1"), reason="See https://github.com/huggingface/diffusers/pull/12456#issuecomment-3424228544", strict=False)
+@pytest.mark.xfail(
+    condition=is_transformers_version(">", "4.57.1"),
+    reason="See https://github.com/huggingface/diffusers/pull/12456#issuecomment-3424228544",
+    strict=False,
+)
 class PhotonPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     pipeline_class = PhotonPipeline
     params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
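For reference, `strict=False` makes the expectation advisory: when the condition holds and the test fails, it is reported as XFAIL; when it unexpectedly passes, it is reported as XPASS rather than failing the suite. A minimal, self-contained illustration of the same marker:

import pytest

@pytest.mark.xfail(condition=True, reason="known incompatibility", strict=False)
def test_marker_semantics():
    # Fails -> XFAIL (suite stays green); passes -> XPASS (also non-fatal, since strict=False).
    assert False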