From fb2fbab10b836aa7f8fcd9a1e6c4f1ba0ebceff1 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 21 Sep 2022 10:57:01 +0200 Subject: [PATCH] Allow dtype to be specified in Flax pipeline (#600) * Fix typo in docstring. * Allow dtype to be overridden on model load. This may be a temporary solution until #567 is addressed. * Create latents in float32. The denoising loop always computes the next step in float32, so this would fail when using `bfloat16`. --- src/diffusers/configuration_utils.py | 5 ++++- .../stable_diffusion/pipeline_flax_stable_diffusion.py | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 2ab85ecee1..1c5c3d7afd 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -154,9 +154,12 @@ class ConfigMixin: """ config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs) - init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs) + # Allow dtype to be specified on initialization + if "dtype" in unused_kwargs: + init_dict["dtype"] = unused_kwargs.pop("dtype") + model = cls(**init_dict) if return_unused_kwargs: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index 7f068d7183..675b612662 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -30,7 +30,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. 
- scheduler ([`FlaxSchedulerMixin`]): + scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], or [`FlaxPNDMScheduler`]. safety_checker ([`FlaxStableDiffusionSafetyChecker`]): @@ -157,7 +157,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): self.unet.sample_size, ) if latents is None: - latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=self.dtype) + latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32) else: if latents.shape != latents_shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")