From 1e216be8950d22813b975a4d026e58011cd78162 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Thu, 26 Jan 2023 14:37:19 +0100 Subject: [PATCH] make scaling factor a config arg of vae/vqvae (#1860) * make scaling factor config arg of vae * fix * make flake happy * fix ldm * fix upscaler * quality * Apply suggestions from code review Co-authored-by: Anton Lozhkov Co-authored-by: Pedro Cuenca Co-authored-by: Patrick von Platen * solve conflicts, address some comments * examples * examples min version * doc * fix type * typo * Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py Co-authored-by: Pedro Cuenca * remove duplicate line * Apply suggestions from code review Co-authored-by: Patrick von Platen Co-authored-by: Anton Lozhkov Co-authored-by: Pedro Cuenca Co-authored-by: Patrick von Platen --- .../community/clip_guided_stable_diffusion.py | 4 ++-- examples/dreambooth/train_dreambooth.py | 2 +- examples/dreambooth/train_dreambooth_flax.py | 2 +- examples/dreambooth/train_dreambooth_lora.py | 2 +- .../colossalai/train_dreambooth_colossalai.py | 2 +- .../train_dreambooth_inpaint.py | 6 ++--- .../textual_inversion_bf16.py | 4 ++-- .../train_multi_subject_dreambooth.py | 4 ++-- examples/text_to_image/train_text_to_image.py | 2 +- .../text_to_image/train_text_to_image_flax.py | 2 +- .../text_to_image/train_text_to_image_lora.py | 2 +- .../textual_inversion/textual_inversion.py | 2 +- .../textual_inversion_flax.py | 2 +- src/diffusers/models/autoencoder_kl.py | 10 ++++++++- src/diffusers/models/vae_flax.py | 10 ++++++++- src/diffusers/models/vq_model.py | 8 +++++++ .../alt_diffusion/pipeline_alt_diffusion.py | 2 +- .../pipeline_alt_diffusion_img2img.py | 4 ++-- .../pipeline_audio_diffusion.py | 4 ++-- src/diffusers/pipelines/dit/pipeline_dit.py | 2 +- .../pipeline_latent_diffusion.py | 2 +- .../pipeline_paint_by_example.py | 4 ++-- .../pipeline_cycle_diffusion.py | 4 ++-- .../pipeline_flax_stable_diffusion.py | 2 +- .../pipeline_flax_stable_diffusion_img2img.py | 4 ++-- .../pipeline_flax_stable_diffusion_inpaint.py | 4 ++-- .../pipeline_stable_diffusion.py | 2 +- .../pipeline_stable_diffusion_depth2img.py | 4 ++-- ...peline_stable_diffusion_image_variation.py | 2 +- .../pipeline_stable_diffusion_img2img.py | 4 ++-- .../pipeline_stable_diffusion_inpaint.py | 4 ++-- ...ipeline_stable_diffusion_inpaint_legacy.py | 4 ++-- ...eline_stable_diffusion_instruct_pix2pix.py | 2 +- .../pipeline_stable_diffusion_k_diffusion.py | 2 +- .../pipeline_stable_diffusion_upscale.py | 22 ++++++++++++++++--- .../pipeline_stable_diffusion_safe.py | 2 +- ...ipeline_versatile_diffusion_dual_guided.py | 2 +- ...ine_versatile_diffusion_image_variation.py | 2 +- ...eline_versatile_diffusion_text_to_image.py | 2 +- 39 files changed, 95 insertions(+), 55 deletions(-) diff --git a/examples/community/clip_guided_stable_diffusion.py b/examples/community/clip_guided_stable_diffusion.py index 71d9b65e50..07c44a9cbf 100644 --- a/examples/community/clip_guided_stable_diffusion.py +++ b/examples/community/clip_guided_stable_diffusion.py @@ -150,7 +150,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline): else: raise ValueError(f"scheduler type {type(self.scheduler)} not supported") - sample = 1 / 0.18215 * sample + sample = 1 / self.vae.config.scaling_factor * sample image = self.vae.decode(sample).sample image = (image / 2 + 0.5).clamp(0, 1) @@ -336,7 +336,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline): latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # 
scale and decode the image latents with vae - latents = 1 / 0.18215 * latents + latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index a2201d4d63..009b5531fb 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -803,7 +803,7 @@ def main(args): with accelerator.accumulate(unet): # Convert images to latent space latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() - latents = latents * 0.18215 + latents = latents * vae.config.scaling_factor # Sample noise that we'll add to the latents noise = torch.randn_like(latents) diff --git a/examples/dreambooth/train_dreambooth_flax.py b/examples/dreambooth/train_dreambooth_flax.py index 247de851ce..1d80781024 100644 --- a/examples/dreambooth/train_dreambooth_flax.py +++ b/examples/dreambooth/train_dreambooth_flax.py @@ -533,7 +533,7 @@ def main(): latents = vae_outputs.latent_dist.sample(sample_rng) # (NHWC) -> (NCHW) latents = jnp.transpose(latents, (0, 3, 1, 2)) - latents = latents * 0.18215 + latents = latents * vae.config.scaling_factor # Sample noise that we'll add to the latents noise_rng, timestep_rng = jax.random.split(sample_rng) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index d30cc1c969..599490b5a9 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -853,7 +853,7 @@ def main(args): with accelerator.accumulate(unet): # Convert images to latent space latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() - latents = latents * 0.18215 + latents = latents * vae.config.scaling_factor # Sample noise that we'll add to the latents noise = torch.randn_like(latents) diff --git a/examples/research_projects/colossalai/train_dreambooth_colossalai.py b/examples/research_projects/colossalai/train_dreambooth_colossalai.py index 40cd16f590..e5039f593d 100644 --- a/examples/research_projects/colossalai/train_dreambooth_colossalai.py +++ b/examples/research_projects/colossalai/train_dreambooth_colossalai.py @@ -607,7 +607,7 @@ def main(args): optimizer.zero_grad() latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() - latents = latents * 0.18215 + latents = latents * vae.config.scaling_factor # Sample noise that we'll add to the latents noise = torch.randn_like(latents) diff --git a/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint.py b/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint.py index 94adb2e1f7..fb21769064 100644 --- a/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint.py +++ b/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint.py @@ -33,7 +33,7 @@ from transformers import CLIPTextModel, CLIPTokenizer # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.10.0.dev0") +check_min_version("0.13.0.dev0") logger = get_logger(__name__) @@ -699,13 +699,13 @@ def main(): # Convert images to latent space latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() - latents = latents * 0.18215 + latents = latents * vae.config.scaling_factor # Convert masked images to latent space masked_latents = vae.encode( batch["masked_images"].reshape(batch["pixel_values"].shape).to(dtype=weight_dtype) ).latent_dist.sample() - masked_latents = masked_latents * 0.18215 + masked_latents = masked_latents * vae.config.scaling_factor masks = batch["masks"] # resize the mask to latents shape as we concatenate the mask to the latents diff --git a/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py b/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py index 966aa64802..f20db249ec 100644 --- a/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py +++ b/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py @@ -51,7 +51,7 @@ else: # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.10.0.dev0") +check_min_version("0.13.0.dev0") logger = get_logger(__name__) @@ -555,7 +555,7 @@ def main(): with accelerator.accumulate(text_encoder): # Convert images to latent space latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() - latents = latents * 0.18215 + latents = latents * vae.config.scaling_factor # Sample noise that we'll add to the latents noise = torch.randn(latents.shape).to(latents.device) diff --git a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py index 307cafebf3..6ec5bef550 100644 --- a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py +++ b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py @@ -31,7 +31,7 @@ from transformers import AutoTokenizer, PretrainedConfig # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.10.0.dev0") +check_min_version("0.13.0.dev0") logger = get_logger(__name__) @@ -788,7 +788,7 @@ def main(args): with accelerator.accumulate(unet): # Convert images to latent space latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() - latents = latents * 0.18215 + latents = latents * vae.config.scaling_factor # Sample noise that we'll add to the latents noise = torch.randn_like(latents) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 4dc3caa131..7f6ddeaee1 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -636,7 +636,7 @@ def main(): with accelerator.accumulate(unet): # Convert images to latent space latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample() - latents = latents * 0.18215 + latents = latents * vae.config.scaling_factor # Sample noise that we'll add to the latents noise = torch.randn_like(latents) diff --git a/examples/text_to_image/train_text_to_image_flax.py b/examples/text_to_image/train_text_to_image_flax.py index 901b212dc5..0850d39de2 100644 --- a/examples/text_to_image/train_text_to_image_flax.py +++ b/examples/text_to_image/train_text_to_image_flax.py @@ -438,7 +438,7 @@ def main(): latents = vae_outputs.latent_dist.sample(sample_rng) # (NHWC) -> (NCHW) latents = jnp.transpose(latents, (0, 3, 1, 2)) - latents = latents * 0.18215 + latents = latents * vae.config.scaling_factor # Sample noise that we'll add to the latents noise_rng, timestep_rng = jax.random.split(sample_rng) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 324d40c2b9..d258c7a4cb 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -689,7 +689,7 @@ def main(): with accelerator.accumulate(unet): # Convert images to latent space latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() - latents = latents * 0.18215 + latents = latents * vae.config.scaling_factor # Sample noise that we'll add to the latents noise = torch.randn_like(latents) diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py index 148e2ade78..1c4bd235e6 100644 --- a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -711,7 +711,7 @@ def main(): with accelerator.accumulate(text_encoder): # Convert images to latent space latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach() - latents = latents * 0.18215 + latents = latents * vae.config.scaling_factor # Sample noise that we'll add to the latents noise = torch.randn_like(latents) diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py index bf00d4304d..b37d1e2ac4 100644 --- a/examples/textual_inversion/textual_inversion_flax.py +++ b/examples/textual_inversion/textual_inversion_flax.py @@ -525,7 +525,7 @@ def main(): latents = vae_outputs.latent_dist.sample(sample_rng) # (NHWC) -> (NCHW) latents = jnp.transpose(latents, (0, 3, 1, 2)) - latents = latents * 0.18215 + latents = latents * vae.config.scaling_factor noise_rng, timestep_rng = jax.random.split(sample_rng) noise = jax.random.normal(noise_rng, latents.shape) diff --git a/src/diffusers/models/autoencoder_kl.py 
b/src/diffusers/models/autoencoder_kl.py index 1bf8662761..18e7eb39c8 100644 --- a/src/diffusers/models/autoencoder_kl.py +++ b/src/diffusers/models/autoencoder_kl.py @@ -54,8 +54,15 @@ class AutoencoderKL(ModelMixin, ConfigMixin): block_out_channels (`Tuple[int]`, *optional*, defaults to : obj:`(64,)`): Tuple of block output channels. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. - latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space. + latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. sample_size (`int`, *optional*, defaults to `32`): TODO + scaling_factor (`float`, *optional*, defaults to 0.18215): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. """ @register_to_config @@ -71,6 +78,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin): latent_channels: int = 4, norm_num_groups: int = 32, sample_size: int = 32, + scaling_factor: float = 0.18215, ): super().__init__() diff --git a/src/diffusers/models/vae_flax.py b/src/diffusers/models/vae_flax.py index 4533bb5551..70cbc3b892 100644 --- a/src/diffusers/models/vae_flax.py +++ b/src/diffusers/models/vae_flax.py @@ -752,8 +752,15 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin): Latent space channels norm_num_groups (:obj:`int`, *optional*, defaults to `32`): Norm num group - sample_size (:obj:`int`, *optional*, defaults to `32`): + sample_size (:obj:`int`, *optional*, defaults to 32): Sample input size + scaling_factor (`float`, *optional*, defaults to 0.18215): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): parameters `dtype` """ @@ -767,6 +774,7 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin): latent_channels: int = 4 norm_num_groups: int = 32 sample_size: int = 32 + scaling_factor: float = 0.18215 dtype: jnp.dtype = jnp.float32 def setup(self): diff --git a/src/diffusers/models/vq_model.py b/src/diffusers/models/vq_model.py index 18fa80cb7b..0341798fc4 100644 --- a/src/diffusers/models/vq_model.py +++ b/src/diffusers/models/vq_model.py @@ -57,6 +57,13 @@ class VQModel(ModelMixin, ConfigMixin): sample_size (`int`, *optional*, defaults to `32`): TODO num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE. vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE. 
+ scaling_factor (`float`, *optional*, defaults to `0.18215`): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. """ @register_to_config @@ -74,6 +81,7 @@ class VQModel(ModelMixin, ConfigMixin): num_vq_embeddings: int = 256, norm_num_groups: int = 32, vq_embed_dim: Optional[int] = None, + scaling_factor: float = 0.18215, ): super().__init__() diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index d074c54634..a198136e78 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -369,7 +369,7 @@ class AltDiffusionPipeline(DiffusionPipeline): return image, has_nsfw_concept def decode_latents(self, latents): - latents = 1 / 0.18215 * latents + latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index 39c314c43e..77a5f66780 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -391,7 +391,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline): return image, has_nsfw_concept def decode_latents(self, latents): - latents = 1 / 0.18215 * latents + latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 @@ -490,7 +490,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline): else: init_latents = self.vae.encode(image).latent_dist.sample(generator) - init_latents = 0.18215 * init_latents + init_latents = self.vae.config.scaling_factor * init_latents if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: # expand init_latents for batch_size diff --git a/src/diffusers/pipelines/audio_diffusion/pipeline_audio_diffusion.py b/src/diffusers/pipelines/audio_diffusion/pipeline_audio_diffusion.py index 14807e3e75..d814c16b08 100644 --- a/src/diffusers/pipelines/audio_diffusion/pipeline_audio_diffusion.py +++ b/src/diffusers/pipelines/audio_diffusion/pipeline_audio_diffusion.py @@ -153,7 +153,7 @@ class AudioDiffusionPipeline(DiffusionPipeline): input_images = self.vqvae.encode(torch.unsqueeze(input_images, 0)).latent_dist.sample( generator=generator )[0] - input_images = 0.18215 * input_images + input_images = self.vqvae.config.scaling_factor * input_images if start_step > 0: images[0, 0] = self.scheduler.add_noise(input_images, noise, self.scheduler.timesteps[start_step - 1]) @@ -195,7 +195,7 @@ class 
AudioDiffusionPipeline(DiffusionPipeline): if self.vqvae is not None: # 0.18215 was scaling factor used in training to ensure unit variance - images = 1 / 0.18215 * images + images = 1 / self.vqvae.config.scaling_factor * images images = self.vqvae.decode(images)["sample"] images = (images / 2 + 0.5).clamp(0, 1) diff --git a/src/diffusers/pipelines/dit/pipeline_dit.py b/src/diffusers/pipelines/dit/pipeline_dit.py index ea372036f9..9ab9a543d5 100644 --- a/src/diffusers/pipelines/dit/pipeline_dit.py +++ b/src/diffusers/pipelines/dit/pipeline_dit.py @@ -182,7 +182,7 @@ class DiTPipeline(DiffusionPipeline): else: latents = latent_model_input - latents = 1 / 0.18215 * latents + latents = 1 / self.vae.config.scaling_factor * latents samples = self.vae.decode(latents).sample samples = (samples / 2 + 0.5).clamp(0, 1) diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py index 0a2448fd9d..35c993ff76 100644 --- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py @@ -182,7 +182,7 @@ class LDMTextToImagePipeline(DiffusionPipeline): latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample # scale and decode the image latents with vae - latents = 1 / 0.18215 * latents + latents = 1 / self.vqvae.config.scaling_factor * latents image = self.vqvae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) diff --git a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py index f24f3547c0..59e6004cbd 100644 --- a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +++ b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py @@ -257,7 +257,7 @@ class PaintByExamplePipeline(DiffusionPipeline): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): - latents = 1 / 0.18215 * latents + latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 @@ -328,7 +328,7 @@ class PaintByExamplePipeline(DiffusionPipeline): masked_image_latents = torch.cat(masked_image_latents, dim=0) else: masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator) - masked_image_latents = 0.18215 * masked_image_latents + masked_image_latents = self.vae.config.scaling_factor * masked_image_latents # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method if mask.shape[0] < batch_size: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py index 7300abbfe4..c592edb89d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py @@ -474,7 +474,7 @@ class CycleDiffusionPipeline(DiffusionPipeline): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): - latents = 1 / 0.18215 * latents + latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample 
image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 @@ -509,7 +509,7 @@ class CycleDiffusionPipeline(DiffusionPipeline): else: init_latents = self.vae.encode(image).latent_dist.sample(generator) - init_latents = 0.18215 * init_latents + init_latents = self.vae.config.scaling_factor * init_latents if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: # expand init_latents for batch_size diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index 307299a0f3..901b4cbdff 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -267,7 +267,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state)) # scale and decode the image latents with vae - latents = 1 / 0.18215 * latents + latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py index 5214b6476f..f555b30b5a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py @@ -224,7 +224,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline): # Create init_latents init_latent_dist = self.vae.apply({"params": params["vae"]}, image, method=self.vae.encode).latent_dist init_latents = init_latent_dist.sample(key=prng_seed).transpose((0, 3, 1, 2)) - init_latents = 0.18215 * init_latents + init_latents = self.vae.config.scaling_factor * init_latents def loop_body(step, args): latents, scheduler_state = args @@ -272,7 +272,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline): latents, _ = jax.lax.fori_loop(start_timestep, num_inference_steps, loop_body, (latents, scheduler_state)) # scale and decode the image latents with vae - latents = 1 / 0.18215 * latents + latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py index 5939595413..5a31309f1a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py @@ -259,7 +259,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline): {"params": params["vae"]}, masked_image, method=self.vae.encode ).latent_dist masked_image_latents = masked_image_latent_dist.sample(key=mask_prng_seed).transpose((0, 3, 1, 2)) - masked_image_latents = 0.18215 * masked_image_latents + masked_image_latents = self.vae.config.scaling_factor * masked_image_latents del mask_prng_seed mask = jax.image.resize(mask, (*mask.shape[:-2], *masked_image_latents.shape[-2:]), 
method="nearest") @@ -327,7 +327,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline): ) # scale and decode the image latents with vae - latents = 1 / 0.18215 * latents + latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 238b34e8fd..1be36cc21b 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -366,7 +366,7 @@ class StableDiffusionPipeline(DiffusionPipeline): return image, has_nsfw_concept def decode_latents(self, latents): - latents = 1 / 0.18215 * latents + latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py index e3211da986..de3b2b9e28 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py @@ -310,7 +310,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): - latents = 1 / 0.18215 * latents + latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 @@ -413,7 +413,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline): else: init_latents = self.vae.encode(image).latent_dist.sample(generator) - init_latents = 0.18215 * init_latents + init_latents = self.vae.config.scaling_factor * init_latents if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: # expand init_latents for batch_size diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py index 51ace48557..8ee662b851 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py @@ -195,7 +195,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): - latents = 1 / 0.18215 * latents + latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index a5e56b2caa..72682701f1 100644 --- 
a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -400,7 +400,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): - latents = 1 / 0.18215 * latents + latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 @@ -500,7 +500,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): else: init_latents = self.vae.encode(image).latent_dist.sample(generator) - init_latents = 0.18215 * init_latents + init_latents = self.vae.config.scaling_factor * init_latents if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: # expand init_latents for batch_size diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 263cc77375..ea7465182a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -466,7 +466,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): - latents = 1 / 0.18215 * latents + latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 @@ -561,7 +561,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): masked_image_latents = torch.cat(masked_image_latents, dim=0) else: masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator) - masked_image_latents = 0.18215 * masked_image_latents + masked_image_latents = self.vae.config.scaling_factor * masked_image_latents # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method if mask.shape[0] < batch_size: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py index 975e8240f0..6846c6f558 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py @@ -367,7 +367,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): - latents = 1 / 0.18215 * latents + latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 @@ -450,7 +450,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): image = image.to(device=self.device, dtype=dtype) init_latent_dist = self.vae.encode(image).latent_dist init_latents = 
init_latent_dist.sample(generator=generator) - init_latents = 0.18215 * init_latents + init_latents = self.vae.config.scaling_factor * init_latents # Expand init_latents for batch_size and num_images_per_prompt init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py index 7923d3a73c..43b7911a19 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py @@ -588,7 +588,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): - latents = 1 / 0.18215 * latents + latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py index 05229c5a1b..39f8e53f63 100755 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py @@ -313,7 +313,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): - latents = 1 / 0.18215 * latents + latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index f2bf6a8bdc..d723a4be8e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -23,7 +23,7 @@ from transformers import CLIPTextModel, CLIPTokenizer from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers -from ...utils import is_accelerate_available, logging, randn_tensor +from ...utils import deprecate, is_accelerate_available, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput @@ -89,6 +89,22 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline): ): super().__init__() + # check if vae has a config attribute `scaling_factor` and if it is set to 0.08333, else set it to 0.08333 and deprecate + is_vae_scaling_factor_set_to_0_08333 = ( + hasattr(vae.config, "scaling_factor") and vae.config.scaling_factor == 0.08333 + ) + if not is_vae_scaling_factor_set_to_0_08333: + deprecation_message = ( + "The configuration file of the vae does not contain `scaling_factor` or it is set to" + f" {vae.config.scaling_factor}, which seems highly unlikely. 
If your checkpoint is a fine-tuned" + " version of `stabilityai/stable-diffusion-x4-upscaler` you should change 'scaling_factor' to 0.08333" + " Please make sure to update the config accordingly, as not doing so might lead to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be" + " very nice if you could open a Pull Request for the `vae/config.json` file" + ) + deprecate("wrong scaling_factor", "1.0.0", deprecation_message, standard_warn=False) + vae.register_to_config(scaling_factor=0.08333) + self.register_modules( vae=vae, text_encoder=text_encoder, @@ -292,9 +308,9 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline): extra_step_kwargs["generator"] = generator return extra_step_kwargs - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents with 0.18215->0.08333 + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): - latents = 1 / 0.08333 * latents + latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 diff --git a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py index fbfc44a8c3..f23c2d36dd 100644 --- a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +++ b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py @@ -364,7 +364,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): - latents = 1 / 0.18215 * latents + latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py index 1d34937877..10356a90cc 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py @@ -330,7 +330,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): - latents = 1 / 0.18215 * latents + latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py index 5acf6f3300..a0bdd6df90 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +++ 
b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py @@ -190,7 +190,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): - latents = 1 / 0.18215 * latents + latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py index 494b9f14f7..600eb953ea 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py @@ -247,7 +247,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): - latents = 1 / 0.18215 * latents + latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
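
After this patch, the encode/decode round trip reads the factor from the model config instead of the hard-coded 0.18215. A minimal sketch of that round trip, assuming a Stable Diffusion v1 checkpoint (the repo id, tensor shapes, and variable names below are illustrative, not taken from the patch):

import torch
from diffusers import AutoencoderKL

# Illustrative checkpoint; any repo whose vae/config.json defines `scaling_factor` works.
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")

image = torch.randn(1, 3, 512, 512)  # placeholder pixel batch in [-1, 1]

# Encode: scale latents toward unit variance before the diffusion model sees them
# (previously the hard-coded `latents * 0.18215`).
latents = vae.encode(image).latent_dist.sample()
latents = latents * vae.config.scaling_factor

# Decode: undo the scaling before reconstructing pixels (previously `1 / 0.18215 * latents`).
image_out = vae.decode(latents / vae.config.scaling_factor).sample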
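
Checkpoints saved before this change carry no `scaling_factor` in vae/config.json and fall back to the new default of 0.18215; the x4 upscaler constructor above backfills 0.08333 with a deprecation warning. A VAE trained with a different latent standard deviation can be patched the same way, using the same `register_to_config` helper the upscale pipeline calls (the 0.13025 value is purely hypothetical):

# Override the factor on an already-loaded model; pipelines pick it up
# via `vae.config.scaling_factor`.
vae.register_to_config(scaling_factor=0.13025)  # hypothetical value for a custom VAE
assert vae.config.scaling_factor == 0.13025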