mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

make scaling factor a config arg of vae/vqvae (#1860)

* make scaling factor config arg of vae

* fix

* make flake happy

* fix ldm

* fix upscaler

* quality

* Apply suggestions from code review

Co-authored-by: Anton Lozhkov <anton@huggingface.co>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* solve conflicts, address some comments

* examples

* examples min version

* doc

* fix type

* typo

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* remove duplicate line

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

Co-authored-by: Anton Lozhkov <anton@huggingface.co>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Author: Suraj Patil
Date: 2023-01-26 14:37:19 +01:00
Committed by: GitHub
Parent: 915a563611
Commit: 1e216be895
39 changed files with 95 additions and 55 deletions
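
In short: pipelines and training scripts stop hardcoding the SD 1.x constant 0.18215 and read it from the model config instead. A minimal sketch of the resulting round-trip, assuming a VAE checkpoint that carries the config entry (the checkpoint id below is illustrative, not taken from the diff):

# Minimal sketch of the pattern this commit standardizes; the checkpoint id
# is illustrative.
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
image = torch.randn(1, 3, 512, 512)  # dummy batch in [-1, 1]

# encode: scale latents to roughly unit variance for the diffusion model
latents = vae.encode(image).latent_dist.sample()
latents = latents * vae.config.scaling_factor  # previously: latents * 0.18215

# decode: undo the scaling first
image_out = vae.decode(latents / vae.config.scaling_factor).sample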

@@ -150,7 +150,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         else:
             raise ValueError(f"scheduler type {type(self.scheduler)} not supported")

-        sample = 1 / 0.18215 * sample
+        sample = 1 / self.vae.config.scaling_factor * sample
         image = self.vae.decode(sample).sample
         image = (image / 2 + 0.5).clamp(0, 1)

@@ -336,7 +336,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
             latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

         # scale and decode the image latents with vae
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)

@@ -803,7 +803,7 @@ def main(args):
             with accelerator.accumulate(unet):
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
-                latents = latents * 0.18215
+                latents = latents * vae.config.scaling_factor

                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(latents)
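
Every training example below makes the same substitution. A hedged, self-contained sketch of the shared pattern (checkpoint id and tensor shapes are illustrative assumptions, not taken from the diff):

# Hedged sketch of the shared training-side pattern: encode images, scale the
# latents by the config value (previously the literal 0.18215), then add
# scheduler noise. Checkpoint id and shapes are illustrative.
import torch
from diffusers import AutoencoderKL, DDPMScheduler

vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
noise_scheduler = DDPMScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler")

pixel_values = torch.randn(2, 3, 512, 512)  # stand-in for batch["pixel_values"]
latents = vae.encode(pixel_values).latent_dist.sample()
latents = latents * vae.config.scaling_factor  # was: latents * 0.18215

noise = torch.randn_like(latents)
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],))
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)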

@@ -533,7 +533,7 @@ def main():
             latents = vae_outputs.latent_dist.sample(sample_rng)
             # (NHWC) -> (NCHW)
             latents = jnp.transpose(latents, (0, 3, 1, 2))
-            latents = latents * 0.18215
+            latents = latents * vae.config.scaling_factor

             # Sample noise that we'll add to the latents
             noise_rng, timestep_rng = jax.random.split(sample_rng)

@@ -853,7 +853,7 @@ def main(args):
             with accelerator.accumulate(unet):
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
-                latents = latents * 0.18215
+                latents = latents * vae.config.scaling_factor

                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(latents)

@@ -607,7 +607,7 @@ def main(args):
                 optimizer.zero_grad()

                 latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
-                latents = latents * 0.18215
+                latents = latents * vae.config.scaling_factor

                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(latents)

@@ -33,7 +33,7 @@ from transformers import CLIPTextModel, CLIPTokenizer

 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.10.0.dev0")
+check_min_version("0.13.0.dev0")

 logger = get_logger(__name__)

@@ -699,13 +699,13 @@ def main():
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
-                latents = latents * 0.18215
+                latents = latents * vae.config.scaling_factor

                 # Convert masked images to latent space
                 masked_latents = vae.encode(
                     batch["masked_images"].reshape(batch["pixel_values"].shape).to(dtype=weight_dtype)
                 ).latent_dist.sample()
-                masked_latents = masked_latents * 0.18215
+                masked_latents = masked_latents * vae.config.scaling_factor

                 masks = batch["masks"]
                 # resize the mask to latents shape as we concatenate the mask to the latents

@@ -51,7 +51,7 @@ else:

 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.10.0.dev0")
+check_min_version("0.13.0.dev0")

 logger = get_logger(__name__)

@@ -555,7 +555,7 @@ def main():
             with accelerator.accumulate(text_encoder):
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
-                latents = latents * 0.18215
+                latents = latents * vae.config.scaling_factor

                 # Sample noise that we'll add to the latents
                 noise = torch.randn(latents.shape).to(latents.device)

@@ -31,7 +31,7 @@ from transformers import AutoTokenizer, PretrainedConfig

 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.10.0.dev0")
+check_min_version("0.13.0.dev0")

 logger = get_logger(__name__)

@@ -788,7 +788,7 @@ def main(args):
             with accelerator.accumulate(unet):
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
-                latents = latents * 0.18215
+                latents = latents * vae.config.scaling_factor

                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(latents)

@@ -636,7 +636,7 @@ def main():
             with accelerator.accumulate(unet):
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
-                latents = latents * 0.18215
+                latents = latents * vae.config.scaling_factor

                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(latents)

@@ -438,7 +438,7 @@ def main():
             latents = vae_outputs.latent_dist.sample(sample_rng)
             # (NHWC) -> (NCHW)
             latents = jnp.transpose(latents, (0, 3, 1, 2))
-            latents = latents * 0.18215
+            latents = latents * vae.config.scaling_factor

             # Sample noise that we'll add to the latents
             noise_rng, timestep_rng = jax.random.split(sample_rng)

@@ -689,7 +689,7 @@ def main():
             with accelerator.accumulate(unet):
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
-                latents = latents * 0.18215
+                latents = latents * vae.config.scaling_factor

                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(latents)

@@ -711,7 +711,7 @@ def main():
             with accelerator.accumulate(text_encoder):
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
-                latents = latents * 0.18215
+                latents = latents * vae.config.scaling_factor

                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(latents)

@@ -525,7 +525,7 @@ def main():
             latents = vae_outputs.latent_dist.sample(sample_rng)
             # (NHWC) -> (NCHW)
             latents = jnp.transpose(latents, (0, 3, 1, 2))
-            latents = latents * 0.18215
+            latents = latents * vae.config.scaling_factor

             noise_rng, timestep_rng = jax.random.split(sample_rng)
             noise = jax.random.normal(noise_rng, latents.shape)
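
The Flax examples follow suit. A hedged JAX sketch of the same step (checkpoint id and the availability of Flax weights are assumptions; diffusers Flax models return the module and its params from `from_pretrained`):

# Hedged JAX sketch of the Flax training-side change; checkpoint id and Flax
# weight availability are assumptions (pass from_pt=True if only PyTorch
# weights exist).
import jax
import jax.numpy as jnp
from diffusers import FlaxAutoencoderKL

vae, vae_params = FlaxAutoencoderKL.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="vae", dtype=jnp.float32
)

rng = jax.random.PRNGKey(0)
pixel_values = jnp.zeros((1, 3, 512, 512))  # NCHW dummy batch

vae_outputs = vae.apply({"params": vae_params}, pixel_values, method=vae.encode)
latents = vae_outputs.latent_dist.sample(rng)
latents = jnp.transpose(latents, (0, 3, 1, 2))  # (NHWC) -> (NCHW)
latents = latents * vae.config.scaling_factor   # was: latents * 0.18215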

@@ -54,8 +54,15 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
         block_out_channels (`Tuple[int]`, *optional*, defaults to :
             obj:`(64,)`): Tuple of block output channels.
         act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
-        latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space.
+        latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
         sample_size (`int`, *optional*, defaults to `32`): TODO
+        scaling_factor (`float`, *optional*, defaults to 0.18215):
+            The component-wise standard deviation of the trained latent space computed using the first batch of the
+            training set. This is used to scale the latent space to have unit variance when training the diffusion
+            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
+            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
+            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
     """

     @register_to_config
@@ -71,6 +78,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
         latent_channels: int = 4,
         norm_num_groups: int = 32,
         sample_size: int = 32,
+        scaling_factor: float = 0.18215,
     ):
         super().__init__()
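
Because `scaling_factor` is registered to the config, a custom VAE can carry its own value, and anything reading `vae.config.scaling_factor` picks it up automatically. A hedged sketch (the 0.13025 value is hypothetical, standing in for a retrained latent space):

# Hedged sketch: register a non-default scaling factor on a freshly built VAE.
# The 0.13025 value is hypothetical.
from diffusers import AutoencoderKL

vae = AutoencoderKL(
    in_channels=3,
    out_channels=3,
    latent_channels=4,
    sample_size=32,
    scaling_factor=0.13025,
)
print(vae.config.scaling_factor)  # -> 0.13025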

@@ -752,8 +752,15 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
             Latent space channels
         norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
             Norm num group
-        sample_size (:obj:`int`, *optional*, defaults to `32`):
+        sample_size (:obj:`int`, *optional*, defaults to 32):
             Sample input size
+        scaling_factor (`float`, *optional*, defaults to 0.18215):
+            The component-wise standard deviation of the trained latent space computed using the first batch of the
+            training set. This is used to scale the latent space to have unit variance when training the diffusion
+            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
+            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
+            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             parameters `dtype`
     """

@@ -767,6 +774,7 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
     latent_channels: int = 4
     norm_num_groups: int = 32
     sample_size: int = 32
+    scaling_factor: float = 0.18215
     dtype: jnp.dtype = jnp.float32

     def setup(self):

@@ -57,6 +57,13 @@ class VQModel(ModelMixin, ConfigMixin):
         sample_size (`int`, *optional*, defaults to `32`): TODO
         num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
         vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE.
+        scaling_factor (`float`, *optional*, defaults to `0.18215`):
+            The component-wise standard deviation of the trained latent space computed using the first batch of the
+            training set. This is used to scale the latent space to have unit variance when training the diffusion
+            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
+            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
+            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
     """

     @register_to_config
@@ -74,6 +81,7 @@ class VQModel(ModelMixin, ConfigMixin):
         num_vq_embeddings: int = 256,
         norm_num_groups: int = 32,
         vq_embed_dim: Optional[int] = None,
+        scaling_factor: float = 0.18215,
     ):
         super().__init__()

@@ -369,7 +369,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
         return image, has_nsfw_concept

     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16

@@ -391,7 +391,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
         return image, has_nsfw_concept

     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16

@@ -490,7 +490,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
         else:
             init_latents = self.vae.encode(image).latent_dist.sample(generator)

-        init_latents = 0.18215 * init_latents
+        init_latents = self.vae.config.scaling_factor * init_latents

         if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
             # expand init_latents for batch_size

@@ -153,7 +153,7 @@ class AudioDiffusionPipeline(DiffusionPipeline):
                 input_images = self.vqvae.encode(torch.unsqueeze(input_images, 0)).latent_dist.sample(
                     generator=generator
                 )[0]
-                input_images = 0.18215 * input_images
+                input_images = self.vqvae.config.scaling_factor * input_images

             if start_step > 0:
                 images[0, 0] = self.scheduler.add_noise(input_images, noise, self.scheduler.timesteps[start_step - 1])

@@ -195,7 +195,7 @@ class AudioDiffusionPipeline(DiffusionPipeline):
         if self.vqvae is not None:
             # 0.18215 was scaling factor used in training to ensure unit variance
-            images = 1 / 0.18215 * images
+            images = 1 / self.vqvae.config.scaling_factor * images
             images = self.vqvae.decode(images)["sample"]

         images = (images / 2 + 0.5).clamp(0, 1)

@@ -182,7 +182,7 @@ class DiTPipeline(DiffusionPipeline):
         else:
             latents = latent_model_input

-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         samples = self.vae.decode(latents).sample

         samples = (samples / 2 + 0.5).clamp(0, 1)

@@ -182,7 +182,7 @@ class LDMTextToImagePipeline(DiffusionPipeline):
             latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample

         # scale and decode the image latents with vae
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vqvae.config.scaling_factor * latents
         image = self.vqvae.decode(latents).sample

         image = (image / 2 + 0.5).clamp(0, 1)

@@ -257,7 +257,7 @@ class PaintByExamplePipeline(DiffusionPipeline):
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16

@@ -328,7 +328,7 @@ class PaintByExamplePipeline(DiffusionPipeline):
             masked_image_latents = torch.cat(masked_image_latents, dim=0)
         else:
             masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
-        masked_image_latents = 0.18215 * masked_image_latents
+        masked_image_latents = self.vae.config.scaling_factor * masked_image_latents

         # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
         if mask.shape[0] < batch_size:

@@ -474,7 +474,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16

@@ -509,7 +509,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
         else:
             init_latents = self.vae.encode(image).latent_dist.sample(generator)

-        init_latents = 0.18215 * init_latents
+        init_latents = self.vae.config.scaling_factor * init_latents

         if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
             # expand init_latents for batch_size

@@ -267,7 +267,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
             latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state))

         # scale and decode the image latents with vae
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample

         image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)

@@ -224,7 +224,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
         # Create init_latents
         init_latent_dist = self.vae.apply({"params": params["vae"]}, image, method=self.vae.encode).latent_dist
         init_latents = init_latent_dist.sample(key=prng_seed).transpose((0, 3, 1, 2))
-        init_latents = 0.18215 * init_latents
+        init_latents = self.vae.config.scaling_factor * init_latents

         def loop_body(step, args):
             latents, scheduler_state = args

@@ -272,7 +272,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
             latents, _ = jax.lax.fori_loop(start_timestep, num_inference_steps, loop_body, (latents, scheduler_state))

         # scale and decode the image latents with vae
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample

         image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)

@@ -259,7 +259,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
         masked_image_latent_dist = self.vae.apply(
             {"params": params["vae"]}, masked_image, method=self.vae.encode
         ).latent_dist
         masked_image_latents = masked_image_latent_dist.sample(key=mask_prng_seed).transpose((0, 3, 1, 2))
-        masked_image_latents = 0.18215 * masked_image_latents
+        masked_image_latents = self.vae.config.scaling_factor * masked_image_latents
         del mask_prng_seed

         mask = jax.image.resize(mask, (*mask.shape[:-2], *masked_image_latents.shape[-2:]), method="nearest")

@@ -327,7 +327,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
         )

         # scale and decode the image latents with vae
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample

         image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)

@@ -366,7 +366,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
         return image, has_nsfw_concept

     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16

@@ -310,7 +310,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16

@@ -413,7 +413,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
         else:
             init_latents = self.vae.encode(image).latent_dist.sample(generator)

-        init_latents = 0.18215 * init_latents
+        init_latents = self.vae.config.scaling_factor * init_latents

         if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
             # expand init_latents for batch_size

@@ -195,7 +195,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16

@@ -400,7 +400,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16

@@ -500,7 +500,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         else:
             init_latents = self.vae.encode(image).latent_dist.sample(generator)

-        init_latents = 0.18215 * init_latents
+        init_latents = self.vae.config.scaling_factor * init_latents

         if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
             # expand init_latents for batch_size

@@ -466,7 +466,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16

@@ -561,7 +561,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
             masked_image_latents = torch.cat(masked_image_latents, dim=0)
         else:
             masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
-        masked_image_latents = 0.18215 * masked_image_latents
+        masked_image_latents = self.vae.config.scaling_factor * masked_image_latents

         # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
         if mask.shape[0] < batch_size:

@@ -367,7 +367,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16

@@ -450,7 +450,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
         image = image.to(device=self.device, dtype=dtype)
         init_latent_dist = self.vae.encode(image).latent_dist
         init_latents = init_latent_dist.sample(generator=generator)
-        init_latents = 0.18215 * init_latents
+        init_latents = self.vae.config.scaling_factor * init_latents

         # Expand init_latents for batch_size and num_images_per_prompt
         init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)

@@ -588,7 +588,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16

@@ -313,7 +313,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16

@@ -23,7 +23,7 @@ from transformers import CLIPTextModel, CLIPTokenizer

 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers
-from ...utils import is_accelerate_available, logging, randn_tensor
+from ...utils import deprecate, is_accelerate_available, logging, randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput

@@ -89,6 +89,22 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
     ):
         super().__init__()

+        # check if vae has a config attribute `scaling_factor` and if it is set to 0.08333, else set it to 0.08333 and deprecate
+        is_vae_scaling_factor_set_to_0_08333 = (
+            hasattr(vae.config, "scaling_factor") and vae.config.scaling_factor == 0.08333
+        )
+        if not is_vae_scaling_factor_set_to_0_08333:
+            deprecation_message = (
+                "The configuration file of the vae does not contain `scaling_factor` or it is set to"
+                f" {vae.config.scaling_factor}, which seems highly unlikely. If your checkpoint is a fine-tuned"
+                " version of `stabilityai/stable-diffusion-x4-upscaler` you should change 'scaling_factor' to 0.08333."
+                " Please make sure to update the config accordingly, as not doing so might lead to incorrect results"
+                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be"
+                " very nice if you could open a Pull Request for the `vae/config.json` file"
+            )
+            deprecate("wrong scaling_factor", "1.0.0", deprecation_message, standard_warn=False)
+            vae.register_to_config(scaling_factor=0.08333)
+
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,

@@ -292,9 +308,9 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
             extra_step_kwargs["generator"] = generator
         return extra_step_kwargs

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents with 0.18215->0.08333
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
-        latents = 1 / 0.08333 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
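
Loading an upscaler checkpoint whose `vae/config.json` predates this change takes the deprecation path above and patches the config in place. A hedged sketch of the same check applied manually (checkpoint id as named in the deprecation message; `register_to_config` is the ConfigMixin method used above):

# Hedged sketch of the backward-compat check above, applied outside the
# pipeline: warn and patch an old vae config instead of failing.
from diffusers import AutoencoderKL
from diffusers.utils import deprecate

vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-x4-upscaler", subfolder="vae")
if getattr(vae.config, "scaling_factor", None) != 0.08333:
    deprecate("wrong scaling_factor", "1.0.0", "please update vae/config.json", standard_warn=False)
    vae.register_to_config(scaling_factor=0.08333)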

@@ -364,7 +364,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16

@@ -330,7 +330,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16

@@ -190,7 +190,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16

@@ -247,7 +247,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16