mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
fix nits.
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2025 The Photoroom and The HuggingFace Teams. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -31,7 +31,7 @@ from ..normalization import RMSNorm
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def get_image_ids(batch_size: int, height: int, width: int, patch_size: int, device: torch.device) -> Tensor:
|
||||
def get_image_ids(batch_size: int, height: int, width: int, patch_size: int, device: torch.device) -> torch.Tensor:
|
||||
r"""
|
||||
Generates 2D patch coordinate indices for a batch of images.
|
||||
|
||||
@@ -59,7 +59,7 @@ def get_image_ids(batch_size: int, height: int, width: int, patch_size: int, dev
|
||||
return img_ids.reshape((height // patch_size) * (width // patch_size), 2).unsqueeze(0).repeat(batch_size, 1, 1)
|
||||
|
||||
|
||||
def apply_rope(xq: Tensor, freqs_cis: Tensor) -> Tensor:
|
||||
def apply_rope(xq: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
||||
r"""
|
||||
Applies rotary positional embeddings (RoPE) to a query tensor.
|
||||
|
||||
@@ -273,7 +273,7 @@ class PhotonEmbedND(nn.Module):
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
|
||||
def rope(self, pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||
def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
|
||||
assert dim % 2 == 0
|
||||
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
||||
omega = 1.0 / (theta**scale)
|
||||
@@ -284,7 +284,7 @@ class PhotonEmbedND(nn.Module):
|
||||
out = out.reshape(*out.shape[:-1], 2, 2)
|
||||
return out.float()
|
||||
|
||||
def forward(self, ids: Tensor) -> Tensor:
|
||||
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
||||
n_axes = ids.shape[-1]
|
||||
emb = torch.cat(
|
||||
[self.rope(ids[:, :, i], self.axes_dim[i], self.theta) for i in range(n_axes)],
|
||||
@@ -314,7 +314,7 @@ class MLPEmbedder(nn.Module):
|
||||
self.silu = nn.SiLU()
|
||||
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.out_layer(self.silu(self.in_layer(x)))
|
||||
|
||||
|
||||
@@ -340,7 +340,7 @@ class Modulation(nn.Module):
|
||||
nn.init.constant_(self.lin.weight, 0)
|
||||
nn.init.constant_(self.lin.bias, 0)
|
||||
|
||||
def forward(self, vec: Tensor) -> Tuple[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor, Tensor]]:
|
||||
def forward(self, vec: torch.Tensor) -> Tuple[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor, Tensor]]:
|
||||
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(6, dim=-1)
|
||||
return tuple(out[:3]), tuple(out[3:])
|
||||
|
||||
@@ -383,7 +383,7 @@ class PhotonBlock(nn.Module):
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qk_scale: float | None = None,
|
||||
qk_scale: Optional[float] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -420,13 +420,13 @@ class PhotonBlock(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: Tensor,
|
||||
encoder_hidden_states: Tensor,
|
||||
temb: Tensor,
|
||||
image_rotary_emb: Tensor,
|
||||
attention_mask: Tensor | None = None,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: torch.Tensor,
|
||||
attention_mask: Optional[Tensor] = None,
|
||||
**kwargs: dict[str, Any],
|
||||
) -> Tensor:
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Runs modulation-gated cross-attention and MLP, with residual connections.
|
||||
|
||||
@@ -503,14 +503,14 @@ class FinalLayer(nn.Module):
|
||||
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
|
||||
def forward(self, x: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
|
||||
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
|
||||
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
def img2seq(img: Tensor, patch_size: int) -> Tensor:
|
||||
def img2seq(img: torch.Tensor, patch_size: int) -> torch.Tensor:
|
||||
r"""
|
||||
Flattens an image tensor into a sequence of non-overlapping patches.
|
||||
|
||||
@@ -528,7 +528,7 @@ def img2seq(img: Tensor, patch_size: int) -> Tensor:
|
||||
return unfold(img, kernel_size=patch_size, stride=patch_size).transpose(1, 2)
|
||||
|
||||
|
||||
def seq2img(seq: Tensor, patch_size: int, shape: Tensor) -> Tensor:
|
||||
def seq2img(seq: torch.Tensor, patch_size: int, shape: torch.Tensor) -> torch.Tensor:
|
||||
r"""
|
||||
Reconstructs an image tensor from a sequence of patches (inverse of `img2seq`).
|
||||
|
||||
@@ -679,7 +679,7 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _compute_timestep_embedding(self, timestep: Tensor, dtype: torch.dtype) -> Tensor:
|
||||
def _compute_timestep_embedding(self, timestep: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
||||
return self.time_in(
|
||||
get_timestep_embedding(
|
||||
timesteps=timestep,
|
||||
@@ -692,9 +692,9 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: Tensor,
|
||||
timestep: Tensor,
|
||||
encoder_hidden_states: Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[Tensor] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The Photoroom and the HuggingFace Teams. All rights reserved.
|
||||
# Copyright 2025 The Photoroom and The HuggingFace Teams. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -17,7 +17,11 @@ from ..pipeline_params import TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
@pytest.mark.xfail(condition=is_transformers_version(">", "4.57.1"), reason="See https://github.com/huggingface/diffusers/pull/12456#issuecomment-3424228544", strict=False)
|
||||
@pytest.mark.xfail(
|
||||
condition=is_transformers_version(">", "4.57.1"),
|
||||
reason="See https://github.com/huggingface/diffusers/pull/12456#issuecomment-3424228544",
|
||||
strict=False,
|
||||
)
|
||||
class PhotonPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = PhotonPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
|
||||
Reference in New Issue
Block a user