diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py
index a511da1be3..cde2a594e6 100644
--- a/src/diffusers/models/unet_2d_condition_flax.py
+++ b/src/diffusers/models/unet_2d_condition_flax.py
@@ -134,8 +134,18 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
 
         added_cond_kwargs = None
         if self.addition_embed_type == "text_time":
-            # TODO: how to get this from the config? It's no longer cross_attention_dim
-            text_embeds_dim = 1280
+            # we retrieve the expected `text_embeds_dim` by first checking if the architecture is a refiner
+            # or non-refiner architecture and then by "reverse-computing" from `projection_class_embeddings_input_dim`
+            is_refiner = (
+                5 * self.config.addition_time_embed_dim + self.config.cross_attention_dim
+                == self.config.projection_class_embeddings_input_dim
+            )
+            num_micro_conditions = 5 if is_refiner else 6
+
+            text_embeds_dim = self.config.projection_class_embeddings_input_dim - (
+                num_micro_conditions * self.config.addition_time_embed_dim
+            )
+
             time_ids_channels = self.projection_class_embeddings_input_dim - text_embeds_dim
             time_ids_dims = time_ids_channels // self.addition_time_embed_dim
             added_cond_kwargs = {
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py
index 9c038320c5..b8651f721c 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py
@@ -215,14 +215,15 @@ class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline):
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-            # scale the initial noise by the standard deviation required by the scheduler
-            latents = latents * params["scheduler"].init_noise_sigma
 
         # Prepare scheduler state
         scheduler_state = self.scheduler.set_timesteps(
             params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape
         )
 
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * scheduler_state.init_noise_sigma
+
         added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
 
         # Denoising loop
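
A quick sanity check of the "reverse-computing" logic in the first hunk, as a standalone sketch outside the diff. The config numbers below (addition_time_embed_dim=256, cross_attention_dim=2048 for the SDXL base UNet and 1280 for the refiner, projection_class_embeddings_input_dim=2816 and 2560 respectively) are taken from the published SDXL configs and are not part of this patch; both cases recover the 1280 that was previously hard-coded.

# Sketch only: reproduces the arithmetic from the patch with assumed SDXL config values.
def text_embeds_dim_from_config(addition_time_embed_dim, cross_attention_dim, projection_class_embeddings_input_dim):
    # The refiner conditions on 5 micro-condition scalars (original size, crop coords, aesthetic score),
    # the base model on 6 (original size, crop coords, target size).
    is_refiner = (
        5 * addition_time_embed_dim + cross_attention_dim == projection_class_embeddings_input_dim
    )
    num_micro_conditions = 5 if is_refiner else 6
    return projection_class_embeddings_input_dim - num_micro_conditions * addition_time_embed_dim

print(text_embeds_dim_from_config(256, 2048, 2816))  # SDXL base:    2816 - 6 * 256 = 1280
print(text_embeds_dim_from_config(256, 1280, 2560))  # SDXL refiner: 2560 - 5 * 256 = 1280

The second hunk only reorders existing statements: `init_noise_sigma` is now read from the `scheduler_state` returned by `set_timesteps` rather than from the raw `params["scheduler"]` state, presumably so the scaling reflects whatever the scheduler computes for the requested timestep schedule.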