mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[SDXL Flax] fix SDXL flax init (#5187)
* fix SDXL flax init * finish * Fix
This commit is contained in:
committed by
GitHub
parent
d9e7857af3
commit
c82f7bafba
@@ -134,8 +134,18 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
|
||||
added_cond_kwargs = None
|
||||
if self.addition_embed_type == "text_time":
|
||||
# TODO: how to get this from the config? It's no longer cross_attention_dim
|
||||
text_embeds_dim = 1280
|
||||
# we retrieve the expected `text_embeds_dim` by first checking if the architecture is a refiner
|
||||
# or non-refiner architecture and then by "reverse-computing" from `projection_class_embeddings_input_dim`
|
||||
is_refiner = (
|
||||
5 * self.config.addition_time_embed_dim + self.config.cross_attention_dim
|
||||
== self.config.projection_class_embeddings_input_dim
|
||||
)
|
||||
num_micro_conditions = 5 if is_refiner else 6
|
||||
|
||||
text_embeds_dim = self.config.projection_class_embeddings_input_dim - (
|
||||
num_micro_conditions * self.config.addition_time_embed_dim
|
||||
)
|
||||
|
||||
time_ids_channels = self.projection_class_embeddings_input_dim - text_embeds_dim
|
||||
time_ids_dims = time_ids_channels // self.addition_time_embed_dim
|
||||
added_cond_kwargs = {
|
||||
|
||||
@@ -215,14 +215,15 @@ class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline):
|
||||
else:
|
||||
if latents.shape != latents_shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * params["scheduler"].init_noise_sigma
|
||||
|
||||
# Prepare scheduler state
|
||||
scheduler_state = self.scheduler.set_timesteps(
|
||||
params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape
|
||||
)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * scheduler_state.init_noise_sigma
|
||||
|
||||
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
||||
|
||||
# Denoising loop
|
||||
|
||||
Reference in New Issue
Block a user