mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Fix Flax flip_sin_to_cos (#1369)
* Fix Flax flip_sin_to_cos * Adding flip_sin_to_cos Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
This commit is contained in:
@@ -84,10 +84,11 @@ class FlaxTimesteps(nn.Module):
|
||||
Time step embedding dimension
|
||||
"""
|
||||
dim: int = 32
|
||||
flip_sin_to_cos: bool = False
|
||||
freq_shift: float = 1
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, timesteps):
|
||||
return get_sinusoidal_embeddings(
|
||||
timesteps, embedding_dim=self.dim, freq_shift=self.freq_shift, flip_sin_to_cos=True
|
||||
timesteps, embedding_dim=self.dim, flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.freq_shift
|
||||
)
|
||||
|
||||
@@ -85,6 +85,10 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
The dimension of the cross attention features.
|
||||
dropout (`float`, *optional*, defaults to 0):
|
||||
Dropout probability for down, up and bottleneck blocks.
|
||||
flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
|
||||
Whether to flip the sin to cos in the time embedding.
|
||||
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
|
||||
|
||||
"""
|
||||
|
||||
sample_size: int = 32
|
||||
@@ -105,6 +109,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
dropout: float = 0.0
|
||||
use_linear_projection: bool = False
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
flip_sin_to_cos: bool = True
|
||||
freq_shift: int = 0
|
||||
|
||||
def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
|
||||
@@ -133,7 +138,9 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
# time
|
||||
self.time_proj = FlaxTimesteps(block_out_channels[0], freq_shift=self.config.freq_shift)
|
||||
self.time_proj = FlaxTimesteps(
|
||||
block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift
|
||||
)
|
||||
self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
|
||||
|
||||
only_cross_attention = self.only_cross_attention
|
||||
|
||||
Reference in New Issue
Block a user