mirror of https://github.com/huggingface/diffusers.git
working state (embeddings)
@@ -31,7 +31,7 @@ def get_timestep_embedding(
     downscale_freq_shift: float = 1,
     scale: float = 1,
     max_period: int = 10000,
-):
+) -> torch.Tensor:
     """
     This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.

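As a quick illustration of what the newly annotated function returns, here is a minimal sketch (values and sizes are purely illustrative; it assumes `get_timestep_embedding` is imported from `diffusers.models.embeddings`):

```python
import torch
from diffusers.models.embeddings import get_timestep_embedding

# DDPM-style sinusoidal embeddings for a batch of 4 timesteps, 320 channels each
# (sizes chosen only for illustration).
timesteps = torch.tensor([0, 250, 500, 999])
emb = get_timestep_embedding(timesteps, embedding_dim=320, flip_sin_to_cos=True, downscale_freq_shift=0)
print(emb.shape)  # torch.Size([4, 320])
```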
@@ -1327,7 +1327,7 @@ class Timesteps(nn.Module):
         self.downscale_freq_shift = downscale_freq_shift
         self.scale = scale

-    def forward(self, timesteps):
+    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
         t_emb = get_timestep_embedding(
             timesteps,
             self.num_channels,
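For context, a small sketch of the `Timesteps` wrapper whose `forward` now carries the same annotations (illustrative sizes; again assuming the `diffusers.models.embeddings` module path):

```python
import torch
from diffusers.models.embeddings import Timesteps, get_timestep_embedding

t = torch.tensor([0, 500, 999])
proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)

# Timesteps.forward delegates to get_timestep_embedding with the settings fixed
# at construction, so both calls below should yield the same [3, 256] tensor.
print(torch.equal(proj(t), get_timestep_embedding(t, 256, flip_sin_to_cos=True, downscale_freq_shift=0)))
```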
@@ -1401,7 +1401,7 @@ class ImagePositionalEmbeddings(nn.Module):
     Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
     height and width of the latent space.

-    For more details, see figure 10 of the dall-e paper: https://huggingface.co/papers/2102.12092
+    For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092

     For VQ-diffusion:

@@ -1637,6 +1637,35 @@ class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
         return conditioning


+class CombinedTimestepTextProjChromaEmbeddings(nn.Module):
+    def __init__(self, factor: int, hidden_dim: int, out_dim: int, n_layers: int, embedding_dim: int):
+        super().__init__()
+
+        self.time_proj = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.guidance_proj = Timesteps(num_channels=factor, flip_sin_to_cos=True, downscale_freq_shift=0)
+
+        self.register_buffer(
+            "mod_proj",
+            get_timestep_embedding(torch.arange(out_dim) * 1000, 2 * factor, flip_sin_to_cos=True, downscale_freq_shift=0),
+            persistent=False,
+        )
+
+    def forward(
+        self, timestep: torch.Tensor, guidance: Optional[torch.Tensor], pooled_projections: torch.Tensor
+    ) -> torch.Tensor:
+        mod_index_length = self.mod_proj.shape[0]
+        timesteps_proj = self.time_proj(timestep).to(dtype=timestep.dtype)
+        guidance_proj = self.guidance_proj(torch.tensor([0])).to(dtype=timestep.dtype, device=timestep.device)
+
+        mod_proj = self.mod_proj.to(dtype=timesteps_proj.dtype, device=timesteps_proj.device)
+        timestep_guidance = (
+            torch.cat([timesteps_proj, guidance_proj], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1)
+        )
+        input_vec = torch.cat([timestep_guidance, mod_proj.unsqueeze(0)], dim=-1)
+
+        return input_vec
+
+
 class CogView3CombinedTimestepSizeEmbeddings(nn.Module):
     def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256):
         super().__init__()
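A rough shape check for the new embedding class. Sizes below are made up rather than the real Chroma configuration, and the class is assumed to be importable from `diffusers.models.embeddings` once this change lands. Note also that, as written in this working-state forward pass, the concatenations only line up for batch size 1 (the guidance projection is computed for a single index):

```python
import torch
from diffusers.models.embeddings import CombinedTimestepTextProjChromaEmbeddings

emb = CombinedTimestepTextProjChromaEmbeddings(
    factor=8, hidden_dim=128, out_dim=12, n_layers=2, embedding_dim=64
)

timestep = torch.tensor([500.0])  # batch of 1
pooled = torch.zeros(1, 64)       # unused by this forward pass, as is guidance

vec = emb(timestep, guidance=None, pooled_projections=pooled)
print(vec.shape)  # torch.Size([1, 12, 32]) -> [batch, out_dim, 4 * factor]
```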
@@ -2230,6 +2259,25 @@ class PixArtAlphaTextProjection(nn.Module):
         return hidden_states


+class ChromaApproximator(nn.Module):
+    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers: int = 5):
+        super().__init__()
+        self.in_proj = nn.Linear(in_dim, hidden_dim, bias=True)
+        self.layers = nn.ModuleList(
+            [PixArtAlphaTextProjection(hidden_dim, hidden_dim, act_fn="silu") for _ in range(n_layers)]
+        )
+        self.norms = nn.ModuleList([nn.RMSNorm(hidden_dim) for _ in range(n_layers)])
+        self.out_proj = nn.Linear(hidden_dim, out_dim)
+
+    def forward(self, x):
+        x = self.in_proj(x)
+
+        for layer, norms in zip(self.layers, self.norms):
+            x = x + layer(norms(x))
+
+        return self.out_proj(x)
+
+
 class IPAdapterPlusImageProjectionBlock(nn.Module):
     def __init__(
         self,
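Finally, a sketch of the approximator on its own (again with made-up sizes, and assuming `ChromaApproximator` ends up in the same module): each `4 * factor`-dimensional vector produced by the embedding class above is mapped to an `out_dim`-dimensional modulation vector. `nn.RMSNorm` is a relatively recent PyTorch addition, so this needs a sufficiently new PyTorch install:

```python
import torch
from diffusers.models.embeddings import ChromaApproximator

# Feed a [batch, num_mod_vectors, 4 * factor] tensor, e.g. the [1, 12, 32]
# output from the embeddings sketch above, through the residual MLP stack.
approx = ChromaApproximator(in_dim=32, out_dim=64, hidden_dim=128, n_layers=2)
x = torch.randn(1, 12, 32)
print(approx(x).shape)  # torch.Size([1, 12, 64])
```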