mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

change to my own unpooled embedder

Edna
2025-06-11 19:35:10 -06:00
committed by GitHub
parent 3309ffef1c
commit 77b429eda4


@@ -1636,36 +1636,46 @@ class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
         return conditioning
 
 
 class CombinedTimestepTextProjChromaEmbeddings(nn.Module):
     def __init__(self, factor: int, hidden_dim: int, out_dim: int, n_layers: int, embedding_dim: int):
         super().__init__()
 
         self.time_proj = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)
         self.guidance_proj = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)
         self.embedder = ChromaApproximator(
             in_dim=factor * 4,
             out_dim=out_dim,
             hidden_dim=hidden_dim,
             n_layers=n_layers,
         )
         self.embedding_dim = embedding_dim
 
         self.register_buffer(
             "mod_proj",
-            get_timestep_embedding(torch.arange(out_dim) * 1000, 2 * factor, flip_sin_to_cos=True, downscale_freq_shift=0),
+            get_timestep_embedding(torch.arange(344), 2 * factor, flip_sin_to_cos=True, downscale_freq_shift=0),
             persistent=False,
         )
 
     def forward(
-        self, timestep: torch.Tensor, guidance: Optional[torch.Tensor], pooled_projections: torch.Tensor
+        self, timestep: torch.Tensor, guidance: Optional[torch.Tensor]
     ) -> torch.Tensor:
         mod_index_length = self.mod_proj.shape[0]
-        timesteps_proj = self.time_proj(timestep).to(dtype=timestep.dtype)
-        guidance_proj = self.guidance_proj(torch.tensor([0])).to(dtype=timestep.dtype, device=timestep.device)
-        mod_proj = self.mod_proj.to(dtype=timesteps_proj.dtype, device=timesteps_proj.device)
+        timesteps_proj = self.time_proj(timestep)
+        if guidance is not None:
+            guidance_proj = self.guidance_proj(guidance.repeat(timesteps_proj.shape[0]))
+        else:
+            guidance_proj = torch.zeros(
+                (1, self.guidance_proj.num_channels),
+                dtype=timesteps_proj.dtype,
+                device=timesteps_proj.device,
+            )
+        mod_proj = self.mod_proj.unsqueeze(0).repeat(timesteps_proj.shape[0], 1, 1).to(dtype=timesteps_proj.dtype, device=timesteps_proj.device)
         timestep_guidance = (
-            torch.cat([timesteps_proj, guidance_proj], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1)
+            torch.cat([timesteps_proj, guidance_proj], dim=1).repeat(1, mod_index_length, 1)
         )
-        input_vec = torch.cat([timestep_guidance, mod_proj.unsqueeze(0)], dim=-1)
+        input_vec = torch.cat([timestep_guidance, mod_proj], dim=-1)
 
         return input_vec
 
 
 class CogView3CombinedTimestepSizeEmbeddings(nn.Module):
     def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256):
         super().__init__()
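
For reference, below is a minimal shape-check sketch (not part of the commit) exercising the new unpooled forward path. It assumes this file is diffusers' models/embeddings.py at this commit, so the class is importable from diffusers.models.embeddings; the constructor values are illustrative placeholders, not Chroma's real configuration.

import torch
from diffusers.models.embeddings import CombinedTimestepTextProjChromaEmbeddings

# Illustrative constructor values (assumptions, not Chroma's real config);
# only `factor` affects the shape of the returned input_vec.
embedder = CombinedTimestepTextProjChromaEmbeddings(
    factor=16, hidden_dim=128, out_dim=3072, n_layers=1, embedding_dim=64
)

timestep = torch.tensor([500.0])  # one denoising timestep (batch of 1)
guidance = torch.tensor([4.0])    # optional guidance scale; may also be None

# After this change the mod_proj buffer is fixed at 344 rows, so forward()
# returns (batch, 344, factor * 4): the (2 * factor)-dim timestep+guidance
# projection concatenated with each (2 * factor)-dim mod_proj row. Note that
# forward() now builds the raw input vector for self.embedder instead of
# consuming pooled_projections.
input_vec = embedder(timestep, guidance)
print(input_vec.shape)  # torch.Size([1, 344, 64])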