mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

working state (embeddings)

Author: Edna
Date: 2025-06-09 21:05:59 -06:00
Committed by: GitHub
Commit: 15f2bd5c39
Parent: e271af9495


@@ -31,7 +31,7 @@ def get_timestep_embedding(
     downscale_freq_shift: float = 1,
     scale: float = 1,
     max_period: int = 10000,
-):
+) -> torch.Tensor:
     """
     This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
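For reference, a minimal sketch of calling get_timestep_embedding directly (assuming an installed diffusers exposing this import path; the dimension and flag values below are illustrative, not taken from the diff):

    import torch
    from diffusers.models.embeddings import get_timestep_embedding

    # One timestep per batch element; with flip_sin_to_cos=True the cosine
    # half comes first along the channel dimension.
    timesteps = torch.tensor([0, 250, 999])
    emb = get_timestep_embedding(timesteps, embedding_dim=320, flip_sin_to_cos=True, downscale_freq_shift=0)
    print(emb.shape)  # torch.Size([3, 320])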
@@ -1327,7 +1327,7 @@ class Timesteps(nn.Module):
         self.downscale_freq_shift = downscale_freq_shift
         self.scale = scale

-    def forward(self, timesteps):
+    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
         t_emb = get_timestep_embedding(
             timesteps,
             self.num_channels,
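The Timesteps module wraps the same computation; a sketch under the same assumptions, showing the tensor-in, tensor-out contract the new annotation documents:

    import torch
    from diffusers.models.embeddings import Timesteps

    time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
    t_emb = time_proj(torch.tensor([1, 500]))
    print(t_emb.shape)  # torch.Size([2, 256])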
@@ -1401,7 +1401,7 @@ class ImagePositionalEmbeddings(nn.Module):
     Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
     height and width of the latent space.

-    For more details, see figure 10 of the dall-e paper: https://huggingface.co/papers/2102.12092
+    For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092

     For VQ-diffusion:
@@ -1637,6 +1637,35 @@ class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
         return conditioning


+class CombinedTimestepTextProjChromaEmbeddings(nn.Module):
+    def __init__(self, factor: int, hidden_dim: int, out_dim: int, n_layers: int, embedding_dim: int):
+        super().__init__()
+
+        self.time_proj = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.guidance_proj = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)
+
+        self.register_buffer(
+            "mod_proj",
+            get_timestep_embedding(torch.arange(out_dim) * 1000, 2 * factor, flip_sin_to_cos=True, downscale_freq_shift=0),
+            persistent=False,
+        )
+
+    def forward(
+        self, timestep: torch.Tensor, guidance: Optional[torch.Tensor], pooled_projections: torch.Tensor
+    ) -> torch.Tensor:
+        mod_index_length = self.mod_proj.shape[0]
+        timesteps_proj = self.time_proj(timestep).to(dtype=timestep.dtype)
+        guidance_proj = self.guidance_proj(torch.tensor([0])).to(dtype=timestep.dtype, device=timestep.device)
+
+        mod_proj = self.mod_proj.to(dtype=timesteps_proj.dtype, device=timesteps_proj.device)
+        timestep_guidance = (
+            torch.cat([timesteps_proj, guidance_proj], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1)
+        )
+        input_vec = torch.cat([timestep_guidance, mod_proj.unsqueeze(0)], dim=-1)
+
+        return input_vec
+
+
 class CogView3CombinedTimestepSizeEmbeddings(nn.Module):
     def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256):
         super().__init__()
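A hedged usage sketch for the new module, assuming the class above is in scope; the dimensions are illustrative guesses, not values from a released checkpoint. Because guidance_proj is computed from a fixed torch.tensor([0]) (batch of one) and concatenated along dim=1, this working-state forward only lines up when the timestep batch is also 1:

    import torch

    # hidden_dim, n_layers, and embedding_dim are accepted but unused in this
    # working-state __init__; only factor and out_dim shape the output.
    emb = CombinedTimestepTextProjChromaEmbeddings(
        factor=16, hidden_dim=5120, out_dim=344, n_layers=5, embedding_dim=64
    )
    timestep = torch.tensor([500.0])
    pooled = torch.zeros(1, 768)  # accepted for interface parity; unused by this forward
    vec = emb(timestep, guidance=None, pooled_projections=pooled)
    print(vec.shape)  # torch.Size([1, 344, 64]): (batch, out_dim, 4 * factor)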
@@ -2230,6 +2259,25 @@ class PixArtAlphaTextProjection(nn.Module):
         return hidden_states


+class ChromaApproximator(nn.Module):
+    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers: int = 5):
+        super().__init__()
+        self.in_proj = nn.Linear(in_dim, hidden_dim, bias=True)
+        self.layers = nn.ModuleList(
+            [PixArtAlphaTextProjection(hidden_dim, hidden_dim, act_fn="silu") for _ in range(n_layers)]
+        )
+        self.norms = nn.ModuleList([nn.RMSNorm(hidden_dim) for _ in range(n_layers)])
+        self.out_proj = nn.Linear(hidden_dim, out_dim)
+
+    def forward(self, x):
+        x = self.in_proj(x)
+
+        for layer, norms in zip(self.layers, self.norms):
+            x = x + layer(norms(x))
+
+        return self.out_proj(x)
+
+
 class IPAdapterPlusImageProjectionBlock(nn.Module):
     def __init__(
         self,
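A sketch chaining the approximator onto the embedding output above, assuming ChromaApproximator and PixArtAlphaTextProjection are in scope (nn.RMSNorm requires PyTorch 2.4 or newer; out_dim here is an illustrative guess at the modulation width):

    import torch

    approx = ChromaApproximator(in_dim=64, out_dim=3072, hidden_dim=5120, n_layers=5)
    input_vec = torch.randn(1, 344, 64)  # e.g. the (batch, out_dim, 4 * factor) output above
    mod_vectors = approx(input_vec)      # pre-norm residual MLP stack, then out_proj
    print(mod_vectors.shape)  # torch.Size([1, 344, 3072])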