From 15f2bd5c3971f94475eacc01c3ac5ac802e32461 Mon Sep 17 00:00:00 2001
From: Edna <88869424+Ednaordinary@users.noreply.github.com>
Date: Mon, 9 Jun 2025 21:05:59 -0600
Subject: [PATCH] working state (embeddings)

---
 src/diffusers/models/embeddings.py | 54 ++++++++++++++++++++++++++++--
 1 file changed, 51 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index c25e9997e3..8aa2ea5841 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -31,7 +31,7 @@ def get_timestep_embedding(
     downscale_freq_shift: float = 1,
     scale: float = 1,
     max_period: int = 10000,
-):
+) -> torch.Tensor:
     """
     This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep
     embeddings.
@@ -1327,7 +1327,7 @@ class Timesteps(nn.Module):
         self.downscale_freq_shift = downscale_freq_shift
         self.scale = scale
 
-    def forward(self, timesteps):
+    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
         t_emb = get_timestep_embedding(
             timesteps,
             self.num_channels,
@@ -1401,7 +1401,7 @@ class ImagePositionalEmbeddings(nn.Module):
     Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
     height and width of the latent space.
 
-    For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092
+    For more details, see figure 10 of the dall-e paper: https://huggingface.co/papers/2102.12092
 
     For VQ-diffusion:
 
@@ -1637,6 +1637,35 @@ class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
         return conditioning
 
 
+class CombinedTimestepTextProjChromaEmbeddings(nn.Module):
+    def __init__(self, factor: int, hidden_dim: int, out_dim: int, n_layers: int, embedding_dim: int):
+        super().__init__()
+
+        self.time_proj = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.guidance_proj = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)
+
+        self.register_buffer(
+            "mod_proj",
+            get_timestep_embedding(torch.arange(out_dim) * 1000, 2 * factor, flip_sin_to_cos=True, downscale_freq_shift=0),
+            persistent=False,
+        )
+
+    def forward(
+        self, timestep: torch.Tensor, guidance: Optional[torch.Tensor], pooled_projections: torch.Tensor
+    ) -> torch.Tensor:
+        mod_index_length = self.mod_proj.shape[0]
+        timesteps_proj = self.time_proj(timestep).to(dtype=timestep.dtype)
+        guidance_proj = self.guidance_proj(torch.tensor([0])).to(dtype=timestep.dtype, device=timestep.device)
+
+        mod_proj = self.mod_proj.to(dtype=timesteps_proj.dtype, device=timesteps_proj.device)
+        timestep_guidance = (
+            torch.cat([timesteps_proj, guidance_proj], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1)
+        )
+        input_vec = torch.cat([timestep_guidance, mod_proj.unsqueeze(0)], dim=-1)
+
+        return input_vec
+
+
 class CogView3CombinedTimestepSizeEmbeddings(nn.Module):
     def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256):
         super().__init__()
@@ -2230,6 +2259,25 @@ class PixArtAlphaTextProjection(nn.Module):
         return hidden_states
 
 
+class ChromaApproximator(nn.Module):
+    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers: int = 5):
+        super().__init__()
+        self.in_proj = nn.Linear(in_dim, hidden_dim, bias=True)
+        self.layers = nn.ModuleList(
+            [PixArtAlphaTextProjection(hidden_dim, hidden_dim, act_fn="silu") for _ in range(n_layers)]
+        )
+        self.norms = nn.ModuleList([nn.RMSNorm(hidden_dim) for _ in range(n_layers)])
+        self.out_proj = nn.Linear(hidden_dim, out_dim)
+
+    def forward(self, x):
+        x = self.in_proj(x)
+
+        for layer, norm in zip(self.layers, self.norms):
+            x = x + layer(norm(x))
+
+        return self.out_proj(x)
+
+
 class IPAdapterPlusImageProjectionBlock(nn.Module):
     def __init__(
         self,
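
Reviewer note: a minimal usage sketch of the two new modules and the shapes they
produce. It assumes this patch is applied; every dimension below (factor=16,
out_dim=344, an approximator output of 3072) is an illustrative stand-in, not a
value taken from this patch or from any Chroma checkpoint.

import torch

from diffusers.models.embeddings import (
    ChromaApproximator,
    CombinedTimestepTextProjChromaEmbeddings,
)

factor, out_dim = 16, 344  # hypothetical; out_dim = number of modulation indices

embedder = CombinedTimestepTextProjChromaEmbeddings(
    factor=factor, hidden_dim=5120, out_dim=out_dim, n_layers=5, embedding_dim=64
)
approximator = ChromaApproximator(in_dim=4 * factor, out_dim=3072, hidden_dim=5120, n_layers=5)

timestep = torch.tensor([250.0])
# guidance and pooled_projections are accepted but unused by this forward pass;
# the guidance embedding is computed from a constant 0 inside forward().
input_vec = embedder(timestep, guidance=None, pooled_projections=torch.zeros(1, 768))
print(input_vec.shape)  # torch.Size([1, 344, 64]) -> (batch, out_dim, 4 * factor)

mod_vectors = approximator(input_vec)
print(mod_vectors.shape)  # torch.Size([1, 344, 3072]) -> one vector per modulation index

The concatenation in forward() tiles the (timestep, guidance) projection across
all out_dim modulation indices and pairs each copy with a fixed sinusoidal
index embedding, so the small residual MLP in ChromaApproximator can emit one
modulation vector per index instead of the model learning separate modulation
parameters per block. Note that nn.RMSNorm requires PyTorch >= 2.4.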