Mirror of https://github.com/huggingface/diffusers.git
change to my own unpooled embedder
@@ -1636,36 +1636,46 @@ class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
         return conditioning
 
 
 class CombinedTimestepTextProjChromaEmbeddings(nn.Module):
     def __init__(self, factor: int, hidden_dim: int, out_dim: int, n_layers: int, embedding_dim: int):
         super().__init__()
 
         self.time_proj = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)
         self.guidance_proj = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)
         self.embedder = ChromaApproximator(
             in_dim=factor * 4,
             out_dim=out_dim,
             hidden_dim=hidden_dim,
             n_layers=n_layers,
         )
         self.embedding_dim = embedding_dim
 
         self.register_buffer(
             "mod_proj",
-            get_timestep_embedding(torch.arange(out_dim) * 1000, 2 * factor, flip_sin_to_cos=True, downscale_freq_shift=0),
+            get_timestep_embedding(torch.arange(344), 2 * factor, flip_sin_to_cos=True, downscale_freq_shift=0),
             persistent=False,
         )
 
     def forward(
-        self, timestep: torch.Tensor, guidance: Optional[torch.Tensor], pooled_projections: torch.Tensor
+        self, timestep: torch.Tensor, guidance: Optional[torch.Tensor]
     ) -> torch.Tensor:
         mod_index_length = self.mod_proj.shape[0]
-        timesteps_proj = self.time_proj(timestep).to(dtype=timestep.dtype)
-        guidance_proj = self.guidance_proj(torch.tensor([0])).to(dtype=timestep.dtype, device=timestep.device)
-
-        mod_proj = self.mod_proj.to(dtype=timesteps_proj.dtype, device=timesteps_proj.device)
+        timesteps_proj = self.time_proj(timestep)
+        if guidance is not None:
+            guidance_proj = self.guidance_proj(guidance.repeat(timesteps_proj.shape[0]))
+        else:
+            guidance_proj = torch.zeros(
+                (1, self.guidance_proj.num_channels),
+                dtype=timesteps_proj.dtype,
+                device=timesteps_proj.device,
+            )
+        mod_proj = self.mod_proj.unsqueeze(0).repeat(timesteps_proj.shape[0], 1, 1).to(dtype=timesteps_proj.dtype, device=timesteps_proj.device)
         timestep_guidance = (
-            torch.cat([timesteps_proj, guidance_proj], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1)
+            torch.cat([timesteps_proj, guidance_proj], dim=1).repeat(1, mod_index_length, 1)
         )
-        input_vec = torch.cat([timestep_guidance, mod_proj.unsqueeze(0)], dim=-1)
-
+        input_vec = torch.cat([timestep_guidance, mod_proj], dim=-1)
         return input_vec
 
 
 class CogView3CombinedTimestepSizeEmbeddings(nn.Module):
     def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256):
         super().__init__()
 
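Below is a minimal, self-contained shape sketch (not part of the commit) of the input vector the patched forward now assembles for ChromaApproximator. The value factor = 16, the single-element batch, and the local variable names are illustrative assumptions; the snippet only mirrors the tensor flow of the new code: no pooled_projections argument, guidance embedded from the actual scale instead of a constant zero, and a fixed 344-entry mod_proj buffer built from raw indices rather than torch.arange(out_dim) * 1000.

import torch
from diffusers.models.embeddings import Timesteps, get_timestep_embedding

factor = 16             # assumed value of the `factor` hyperparameter
mod_index_length = 344  # length of the hard-coded mod_proj buffer in this commit

# Same projections the module builds in __init__
time_proj = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)
guidance_proj_layer = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)
mod_proj_buf = get_timestep_embedding(
    torch.arange(mod_index_length), 2 * factor, flip_sin_to_cos=True, downscale_freq_shift=0
)  # (344, 2 * factor)

timestep = torch.tensor([500.0])  # one diffusion timestep
guidance = torch.tensor([4.0])    # guidance scale; None falls back to a zero embedding in the new code

# Tensor flow of the patched forward for a batch of one
timesteps_proj = time_proj(timestep)                                           # (1, factor)
guidance_proj = guidance_proj_layer(guidance.repeat(timesteps_proj.shape[0]))  # (1, factor)
mod_proj = mod_proj_buf.unsqueeze(0).repeat(timesteps_proj.shape[0], 1, 1)     # (1, 344, 2 * factor)

timestep_guidance = torch.cat([timesteps_proj, guidance_proj], dim=1).repeat(1, mod_index_length, 1)
# repeat() promotes the (1, 2 * factor) tensor to (1, 344, 2 * factor) for this batch of one

input_vec = torch.cat([timestep_guidance, mod_proj], dim=-1)
print(input_vec.shape)  # torch.Size([1, 344, 64]), i.e. (1, 344, 4 * factor)

The 4 * factor width of input_vec matches the in_dim=factor * 4 that ChromaApproximator is constructed with in __init__.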