1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[SDXL Flax] fix SDXL flax init (#5187)

* fix SDXL flax init

* finish

* Fix
This commit is contained in:
Patrick von Platen
2023-09-26 19:55:05 +02:00
committed by GitHub
parent d9e7857af3
commit c82f7bafba
2 changed files with 15 additions and 4 deletions

View File

@@ -134,8 +134,18 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
added_cond_kwargs = None
if self.addition_embed_type == "text_time":
# TODO: how to get this from the config? It's no longer cross_attention_dim
text_embeds_dim = 1280
# we retrieve the expected `text_embeds_dim` by first checking if the architecture is a refiner
# or non-refiner architecture and then by "reverse-computing" from `projection_class_embeddings_input_dim`
is_refiner = (
5 * self.config.addition_time_embed_dim + self.config.cross_attention_dim
== self.config.projection_class_embeddings_input_dim
)
num_micro_conditions = 5 if is_refiner else 6
text_embeds_dim = self.config.projection_class_embeddings_input_dim - (
num_micro_conditions * self.config.addition_time_embed_dim
)
time_ids_channels = self.projection_class_embeddings_input_dim - text_embeds_dim
time_ids_dims = time_ids_channels // self.addition_time_embed_dim
added_cond_kwargs = {

View File

@@ -215,14 +215,15 @@ class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline):
else:
if latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * params["scheduler"].init_noise_sigma
# Prepare scheduler state
scheduler_state = self.scheduler.set_timesteps(
params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape
)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * scheduler_state.init_noise_sigma
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
# Denoising loop