From 77b429eda416f0f6645b591b370971913f6bdbf5 Mon Sep 17 00:00:00 2001
From: Edna <88869424+Ednaordinary@users.noreply.github.com>
Date: Wed, 11 Jun 2025 19:35:10 -0600
Subject: [PATCH] change to my own unpooled embedder

---
 src/diffusers/models/embeddings.py | 32 ++++++++++++++++++++----------
 1 file changed, 21 insertions(+), 11 deletions(-)

diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 8aa2ea5841..0ba64eadf2 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -1636,36 +1636,46 @@ class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
         return conditioning
 
 
-
 class CombinedTimestepTextProjChromaEmbeddings(nn.Module):
     def __init__(self, factor: int, hidden_dim: int, out_dim: int, n_layers: int, embedding_dim: int):
         super().__init__()
 
         self.time_proj = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)
         self.guidance_proj = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.embedder = ChromaApproximator(
+            in_dim=factor * 4,
+            out_dim=out_dim,
+            hidden_dim=hidden_dim,
+            n_layers=n_layers,
+        )
+        self.embedding_dim = embedding_dim
 
         self.register_buffer(
             "mod_proj",
-            get_timestep_embedding(torch.arange(out_dim)*1000, 2 * factor, flip_sin_to_cos=True, downscale_freq_shift=0, ),
+            get_timestep_embedding(torch.arange(344), 2 * factor, flip_sin_to_cos=True, downscale_freq_shift=0),
             persistent=False,
         )
 
     def forward(
-        self, timestep: torch.Tensor, guidance: Optional[torch.Tensor], pooled_projections: torch.Tensor
+        self, timestep: torch.Tensor, guidance: Optional[torch.Tensor]
     ) -> torch.Tensor:
         mod_index_length = self.mod_proj.shape[0]
-        timesteps_proj = self.time_proj(timestep).to(dtype=timestep.dtype)
-        guidance_proj = self.guidance_proj(torch.tensor([0])).to(dtype=timestep.dtype, device=timestep.device)
-
-        mod_proj = self.mod_proj.to(dtype=timesteps_proj.dtype, device=timesteps_proj.device)
+        timesteps_proj = self.time_proj(timestep)
+        if guidance is not None:
+            guidance_proj = self.guidance_proj(guidance.repeat(timesteps_proj.shape[0]))
+        else:
+            guidance_proj = torch.zeros(
+                (1, self.guidance_proj.num_channels),
+                dtype=timesteps_proj.dtype,
+                device=timesteps_proj.device,
+            )
+        mod_proj = self.mod_proj.unsqueeze(0).repeat(timesteps_proj.shape[0], 1, 1).to(dtype=timesteps_proj.dtype, device=timesteps_proj.device)
         timestep_guidance = (
-            torch.cat([timesteps_proj, guidance_proj], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1)
+            torch.cat([timesteps_proj, guidance_proj], dim=1).repeat(1, mod_index_length, 1)
         )
-        input_vec = torch.cat([timestep_guidance, mod_proj.unsqueeze(0)], dim=-1)
-
+        input_vec = torch.cat([timestep_guidance, mod_proj], dim=-1)
         return input_vec
 
-
 class CogView3CombinedTimestepSizeEmbeddings(nn.Module):
     def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256):
         super().__init__()
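
For reviewers, below is a minimal, self-contained sketch of the tensor shapes the patched `forward` produces. It assumes a batch of one (the `.repeat(1, mod_index_length, 1)` call on the 2-D concat result broadcasts a leading singleton dimension, so batch 1 is the path exercised here), an assumed `factor = 16`, and a simplified `sinusoidal` helper standing in for diffusers' `Timesteps` / `get_timestep_embedding`; it is a shape check, not the library implementation. Only the 344 modulation indices and the `factor * 4` input width come from the patch itself.

```python
import torch


def sinusoidal(x: torch.Tensor, dim: int) -> torch.Tensor:
    # Simplified stand-in for diffusers' get_timestep_embedding /
    # Timesteps with flip_sin_to_cos=True ([cos | sin] layout).
    half = dim // 2
    freqs = torch.exp(-torch.arange(half, dtype=torch.float32) * (torch.log(torch.tensor(10000.0)) / half))
    args = x.float()[:, None] * freqs[None, :]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)


factor = 16               # assumed value for illustration
mod_index_length = 344    # modulation-index count hard-coded in the patch

timestep = torch.tensor([400.0])  # batch of one
guidance = torch.tensor([4.0])

timesteps_proj = sinusoidal(timestep, factor)                                 # (1, factor)
guidance_proj = sinusoidal(guidance.repeat(timesteps_proj.shape[0]), factor)  # (1, factor)

mod_proj = sinusoidal(torch.arange(mod_index_length), 2 * factor)             # (344, 2 * factor)
mod_proj = mod_proj.unsqueeze(0).repeat(timesteps_proj.shape[0], 1, 1)        # (1, 344, 2 * factor)

# Repeating a 2-D tensor with three sizes adds a leading singleton dim,
# so (1, 2 * factor) becomes (1, 344, 2 * factor) here.
timestep_guidance = torch.cat([timesteps_proj, guidance_proj], dim=1).repeat(1, mod_index_length, 1)

input_vec = torch.cat([timestep_guidance, mod_proj], dim=-1)
print(input_vec.shape)  # torch.Size([1, 344, 64]) == (1, 344, factor * 4)
```

The final `(1, 344, factor * 4)` tensor matches the `in_dim=factor * 4` the patch passes to `ChromaApproximator`, i.e. the approximator consumes one row per modulation index rather than a pooled projection, which is why the `pooled_projections` argument is dropped from `forward`.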