diff --git a/src/diffusers/models/embeddings_flax.py b/src/diffusers/models/embeddings_flax.py
index bf7d54b82e..fbb1dbab3c 100644
--- a/src/diffusers/models/embeddings_flax.py
+++ b/src/diffusers/models/embeddings_flax.py
@@ -84,10 +84,11 @@ class FlaxTimesteps(nn.Module):
             Time step embedding dimension
     """
     dim: int = 32
+    flip_sin_to_cos: bool = False
     freq_shift: float = 1
 
     @nn.compact
     def __call__(self, timesteps):
         return get_sinusoidal_embeddings(
-            timesteps, embedding_dim=self.dim, freq_shift=self.freq_shift, flip_sin_to_cos=True
+            timesteps, embedding_dim=self.dim, flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.freq_shift
         )
diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py
index 8a33853700..3a3f1d9e14 100644
--- a/src/diffusers/models/unet_2d_condition_flax.py
+++ b/src/diffusers/models/unet_2d_condition_flax.py
@@ -85,6 +85,10 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
             The dimension of the cross attention features.
         dropout (`float`, *optional*, defaults to 0):
             Dropout probability for down, up and bottleneck blocks.
+        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
+            Whether to flip the sin to cos in the time embedding.
+        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
+
     """
 
     sample_size: int = 32
@@ -105,6 +109,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
     dropout: float = 0.0
     use_linear_projection: bool = False
     dtype: jnp.dtype = jnp.float32
+    flip_sin_to_cos: bool = True
     freq_shift: int = 0
 
     def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
@@ -133,7 +138,9 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
         )
 
         # time
-        self.time_proj = FlaxTimesteps(block_out_channels[0], freq_shift=self.config.freq_shift)
+        self.time_proj = FlaxTimesteps(
+            block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.config.freq_shift
+        )
         self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
 
         only_cross_attention = self.only_cross_attention
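
For context on the change above: FlaxTimesteps previously hard-coded flip_sin_to_cos=True in its call to get_sinusoidal_embeddings; the patch surfaces it as a module field and plumbs it through from the UNet config. The UNet's default of True preserves its existing behavior, while a standalone FlaxTimesteps now defaults to False. The sketch below illustrates what the flag controls; the function name sinusoidal_embeddings and the exact frequency formula are illustrative assumptions, not the diffusers implementation.

# Illustrative sketch only -- not the diffusers implementation. It assumes a
# DDPM-style log-spaced frequency schedule to show what `flip_sin_to_cos`
# and `freq_shift` control in a sinusoidal timestep embedding.
import jax.numpy as jnp

def sinusoidal_embeddings(timesteps, embedding_dim, freq_shift=1.0, flip_sin_to_cos=False):
    half_dim = embedding_dim // 2
    # `freq_shift` shifts the denominator of the log-frequency spacing (assumed formula).
    exponent = -jnp.log(10000.0) * jnp.arange(half_dim) / (half_dim - freq_shift)
    emb = timesteps[:, None].astype(jnp.float32) * jnp.exp(exponent)[None, :]
    sin, cos = jnp.sin(emb), jnp.cos(emb)
    # The flag only swaps the concatenation order of the two halves.
    if flip_sin_to_cos:
        return jnp.concatenate([cos, sin], axis=-1)
    return jnp.concatenate([sin, cos], axis=-1)

emb = sinusoidal_embeddings(jnp.array([0, 10, 999]), embedding_dim=32, flip_sin_to_cos=True)
print(emb.shape)  # (3, 32)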