mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Flax] time embedding (#1081)
* initial get_sinusoidal_embeddings * added asserts * better var name * fix docs
This commit is contained in:
@@ -17,23 +17,41 @@ import flax.linen as nn
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
# This is like models.embeddings.get_timestep_embedding (PyTorch) but
|
||||
# less general (only handles the case we currently need).
|
||||
def get_sinusoidal_embeddings(timesteps, embedding_dim, freq_shift: float = 1):
|
||||
def get_sinusoidal_embeddings(
|
||||
timesteps: jnp.ndarray,
|
||||
embedding_dim: int,
|
||||
freq_shift: float = 1,
|
||||
min_timescale: float = 1,
|
||||
max_timescale: float = 1.0e4,
|
||||
flip_sin_to_cos: bool = False,
|
||||
scale: float = 1.0,
|
||||
) -> jnp.ndarray:
|
||||
"""Returns the positional encoding (same as Tensor2Tensor).
|
||||
Args:
|
||||
timesteps: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
embedding_dim: The number of output channels.
|
||||
min_timescale: The smallest time unit (should probably be 0.0).
|
||||
max_timescale: The largest time unit.
|
||||
Returns:
|
||||
a Tensor of timing signals [N, num_channels]
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
||||
assert timesteps.ndim == 1, "Timesteps should be a 1d-array"
|
||||
assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even"
|
||||
num_timescales = float(embedding_dim // 2)
|
||||
log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift)
|
||||
inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment)
|
||||
emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0)
|
||||
|
||||
:param timesteps: a 1-D tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
|
||||
embeddings. :return: an [N x dim] tensor of positional embeddings.
|
||||
"""
|
||||
half_dim = embedding_dim // 2
|
||||
emb = math.log(10000) / (half_dim - freq_shift)
|
||||
emb = jnp.exp(jnp.arange(half_dim) * -emb)
|
||||
emb = timesteps[:, None] * emb[None, :]
|
||||
emb = jnp.concatenate([jnp.cos(emb), jnp.sin(emb)], -1)
|
||||
return emb
|
||||
# scale embeddings
|
||||
scaled_time = scale * emb
|
||||
|
||||
if flip_sin_to_cos:
|
||||
signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=1)
|
||||
else:
|
||||
signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1)
|
||||
signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim])
|
||||
return signal
|
||||
|
||||
|
||||
class FlaxTimestepEmbedding(nn.Module):
|
||||
@@ -70,4 +88,4 @@ class FlaxTimesteps(nn.Module):
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, timesteps):
|
||||
return get_sinusoidal_embeddings(timesteps, self.dim, freq_shift=self.freq_shift)
|
||||
return get_sinusoidal_embeddings(timesteps, embedding_dim=self.dim, freq_shift=self.freq_shift)
|
||||
|
||||
Reference in New Issue
Block a user