mirror of https://github.com/huggingface/diffusers.git
make scaling factor a config arg of vae/vqvae (#1860)
* make scaling factor a config arg of vae
* fix
* make flake happy
* fix ldm
* fix upscaler
* quality
* Apply suggestions from code review
* solve conflicts, address some comments
* examples
* examples min version
* doc
* fix type
* typo
* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
* remove duplicate line
* Apply suggestions from code review

Co-authored-by: Anton Lozhkov <anton@huggingface.co>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
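At its core, the commit replaces the magic number 0.18215 (and 0.08333 for the x4 upscaler) hardcoded throughout pipelines and training scripts with a `scaling_factor` entry in the VAE/VQ-VAE config. A minimal before/after sketch of the encode side; the checkpoint name is illustrative and the random tensor stands in for a preprocessed image batch:

```python
import torch
from diffusers import AutoencoderKL

# Illustrative checkpoint; any AutoencoderKL whose config carries `scaling_factor` works.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")

pixel_values = torch.randn(1, 3, 512, 512)  # stand-in for a preprocessed image batch
with torch.no_grad():
    latents = vae.encode(pixel_values).latent_dist.sample()

# Before this PR the scale was baked into every call site:
#   latents = latents * 0.18215
# After it, the value is read from the model config:
latents = latents * vae.config.scaling_factor
```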
@@ -150,7 +150,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         else:
             raise ValueError(f"scheduler type {type(self.scheduler)} not supported")
 
-        sample = 1 / 0.18215 * sample
+        sample = 1 / self.vae.config.scaling_factor * sample
         image = self.vae.decode(sample).sample
         image = (image / 2 + 0.5).clamp(0, 1)
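The hunk above is the decode-side counterpart of the change: divide by the same factor before handing latents to the VAE decoder. A self-contained sketch of that pattern (checkpoint and latent shape are illustrative):

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")  # illustrative checkpoint

latents = torch.randn(1, 4, 64, 64)  # stand-in for denoised latents from a scheduler
with torch.no_grad():
    # Undo the scale applied at encode time, then map back to pixel space.
    image = vae.decode(1 / vae.config.scaling_factor * latents).sample
image = (image / 2 + 0.5).clamp(0, 1)  # [-1, 1] -> [0, 1]
```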
@@ -336,7 +336,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
             latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
 
         # scale and decode the image latents with vae
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
 
         image = (image / 2 + 0.5).clamp(0, 1)
@@ -803,7 +803,7 @@ def main(args):
             with accelerator.accumulate(unet):
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
-                latents = latents * 0.18215
+                latents = latents * vae.config.scaling_factor
 
                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(latents)
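All of the training-script hunks that follow share this shape: encode a batch, scale the latents, then sample the noise the UNet learns to predict. A condensed, runnable sketch with stand-ins for the script's `batch` and `weight_dtype`:

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")  # illustrative checkpoint
weight_dtype = torch.float32
batch = {"pixel_values": torch.randn(2, 3, 256, 256)}  # stand-in for a dataloader batch

# Convert images to latent space, scaled by the config value (was the hardcoded 0.18215).
with torch.no_grad():
    latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * vae.config.scaling_factor

# Sample the noise the UNet will be trained to predict.
noise = torch.randn_like(latents)
```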
@@ -533,7 +533,7 @@ def main():
             latents = vae_outputs.latent_dist.sample(sample_rng)
             # (NHWC) -> (NCHW)
             latents = jnp.transpose(latents, (0, 3, 1, 2))
-            latents = latents * 0.18215
+            latents = latents * vae.config.scaling_factor
 
             # Sample noise that we'll add to the latents
             noise_rng, timestep_rng = jax.random.split(sample_rng)
@@ -853,7 +853,7 @@ def main(args):
             with accelerator.accumulate(unet):
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
-                latents = latents * 0.18215
+                latents = latents * vae.config.scaling_factor
 
                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(latents)
@@ -607,7 +607,7 @@ def main(args):
             optimizer.zero_grad()
 
             latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
-            latents = latents * 0.18215
+            latents = latents * vae.config.scaling_factor
 
             # Sample noise that we'll add to the latents
             noise = torch.randn_like(latents)
@@ -33,7 +33,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
 
 
 # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
-check_min_version("0.10.0.dev0")
+check_min_version("0.13.0.dev0")
 
 logger = get_logger(__name__)
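The accompanying version bump is how the example scripts guard against an older `diffusers` install, presumably one whose VAE configs may not yet populate `scaling_factor`. The guard itself is a one-liner:

```python
from diffusers.utils import check_min_version

# Raises if the installed diffusers is older than the examples require.
check_min_version("0.13.0.dev0")
```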
@@ -699,13 +699,13 @@ def main():
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
-                latents = latents * 0.18215
+                latents = latents * vae.config.scaling_factor
 
                 # Convert masked images to latent space
                 masked_latents = vae.encode(
                     batch["masked_images"].reshape(batch["pixel_values"].shape).to(dtype=weight_dtype)
                 ).latent_dist.sample()
-                masked_latents = masked_latents * 0.18215
+                masked_latents = masked_latents * vae.config.scaling_factor
 
                 masks = batch["masks"]
                 # resize the mask to latents shape as we concatenate the mask to the latents
@@ -51,7 +51,7 @@ else:
 
 
 # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
-check_min_version("0.10.0.dev0")
+check_min_version("0.13.0.dev0")
 
 
 logger = get_logger(__name__)
@@ -555,7 +555,7 @@ def main():
             with accelerator.accumulate(text_encoder):
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
-                latents = latents * 0.18215
+                latents = latents * vae.config.scaling_factor
 
                 # Sample noise that we'll add to the latents
                 noise = torch.randn(latents.shape).to(latents.device)
@@ -31,7 +31,7 @@ from transformers import AutoTokenizer, PretrainedConfig
 
 
 # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
-check_min_version("0.10.0.dev0")
+check_min_version("0.13.0.dev0")
 
 logger = get_logger(__name__)
@@ -788,7 +788,7 @@ def main(args):
             with accelerator.accumulate(unet):
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
-                latents = latents * 0.18215
+                latents = latents * vae.config.scaling_factor
 
                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(latents)
@@ -636,7 +636,7 @@ def main():
             with accelerator.accumulate(unet):
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
-                latents = latents * 0.18215
+                latents = latents * vae.config.scaling_factor
 
                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(latents)
@@ -438,7 +438,7 @@ def main():
             latents = vae_outputs.latent_dist.sample(sample_rng)
             # (NHWC) -> (NCHW)
             latents = jnp.transpose(latents, (0, 3, 1, 2))
-            latents = latents * 0.18215
+            latents = latents * vae.config.scaling_factor
 
             # Sample noise that we'll add to the latents
             noise_rng, timestep_rng = jax.random.split(sample_rng)
@@ -689,7 +689,7 @@ def main():
             with accelerator.accumulate(unet):
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
-                latents = latents * 0.18215
+                latents = latents * vae.config.scaling_factor
 
                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(latents)
@@ -711,7 +711,7 @@ def main():
             with accelerator.accumulate(text_encoder):
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
-                latents = latents * 0.18215
+                latents = latents * vae.config.scaling_factor
 
                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(latents)
@@ -525,7 +525,7 @@ def main():
             latents = vae_outputs.latent_dist.sample(sample_rng)
             # (NHWC) -> (NCHW)
             latents = jnp.transpose(latents, (0, 3, 1, 2))
-            latents = latents * 0.18215
+            latents = latents * vae.config.scaling_factor
 
             noise_rng, timestep_rng = jax.random.split(sample_rng)
             noise = jax.random.normal(noise_rng, latents.shape)
@@ -54,8 +54,15 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
         block_out_channels (`Tuple[int]`, *optional*, defaults to :
             obj:`(64,)`): Tuple of block output channels.
         act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
-        latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space.
+        latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
         sample_size (`int`, *optional*, defaults to `32`): TODO
+        scaling_factor (`float`, *optional*, defaults to 0.18215):
+            The component-wise standard deviation of the trained latent space computed using the first batch of the
+            training set. This is used to scale the latent space to have unit variance when training the diffusion
+            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
+            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
+            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
     """
 
     @register_to_config
@@ -71,6 +78,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
         latent_channels: int = 4,
         norm_num_groups: int = 32,
         sample_size: int = 32,
+        scaling_factor: float = 0.18215,
     ):
         super().__init__()
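The docstring's claim is checkable: multiplying encoded latents by `scaling_factor` should bring their component-wise standard deviation near 1 on data resembling the calibration batch. A sketch with an illustrative checkpoint and a random stand-in batch (real images would track the claim more closely):

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")  # illustrative checkpoint
images = torch.rand(4, 3, 256, 256) * 2 - 1  # stand-in batch in [-1, 1]

with torch.no_grad():
    latents = vae.encode(images).latent_dist.sample()

scaled = latents * vae.config.scaling_factor  # z = z * scaling_factor
print(vae.config.scaling_factor)  # 0.18215 for the Stable Diffusion VAE
print(scaled.std())  # roughly 1.0 on data similar to the training set
```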
@@ -752,8 +752,15 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
             Latent space channels
         norm_num_groups (:obj:`int`, *optional*, defaults to `32`):
             Norm num group
-        sample_size (:obj:`int`, *optional*, defaults to `32`):
+        sample_size (:obj:`int`, *optional*, defaults to 32):
             Sample input size
+        scaling_factor (`float`, *optional*, defaults to 0.18215):
+            The component-wise standard deviation of the trained latent space computed using the first batch of the
+            training set. This is used to scale the latent space to have unit variance when training the diffusion
+            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
+            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
+            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
             parameters `dtype`
     """
@@ -767,6 +774,7 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
     latent_channels: int = 4
     norm_num_groups: int = 32
     sample_size: int = 32
+    scaling_factor: float = 0.18215
     dtype: jnp.dtype = jnp.float32
 
     def setup(self):
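The Flax module mirrors the PyTorch config field one-to-one, so the same value is available after loading pretrained parameters. A sketch following the pattern of the Flax training scripts; the checkpoint and revision are illustrative:

```python
from diffusers import FlaxAutoencoderKL

# Flax from_pretrained returns the module plus its parameters.
vae, vae_params = FlaxAutoencoderKL.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="vae", revision="flax"
)
print(vae.config.scaling_factor)  # 0.18215, the same default as the PyTorch module
```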
@@ -57,6 +57,13 @@ class VQModel(ModelMixin, ConfigMixin):
         sample_size (`int`, *optional*, defaults to `32`): TODO
         num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
         vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE.
+        scaling_factor (`float`, *optional*, defaults to `0.18215`):
+            The component-wise standard deviation of the trained latent space computed using the first batch of the
+            training set. This is used to scale the latent space to have unit variance when training the diffusion
+            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
+            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
+            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
     """
 
     @register_to_config
@@ -74,6 +81,7 @@ class VQModel(ModelMixin, ConfigMixin):
         num_vq_embeddings: int = 256,
         norm_num_groups: int = 32,
         vq_embed_dim: Optional[int] = None,
+        scaling_factor: float = 0.18215,
     ):
         super().__init__()
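`VQModel` receives the same config field with the same default, so VQ-VAE-backed pipelines (LDM text-to-image, audio diffusion) read it exactly like the KL pipelines do; per-checkpoint configs can override it. A direct-construction sketch using the defaults documented above:

```python
from diffusers import VQModel

# Defaults mirror the docstring; checkpoints can override scaling_factor in config.json.
vqvae = VQModel(num_vq_embeddings=256, scaling_factor=0.18215)
print(vqvae.config.scaling_factor)  # 0.18215
```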
@@ -369,7 +369,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
         return image, has_nsfw_concept
 
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
@@ -391,7 +391,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
         return image, has_nsfw_concept
 
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
@@ -490,7 +490,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
         else:
             init_latents = self.vae.encode(image).latent_dist.sample(generator)
 
-        init_latents = 0.18215 * init_latents
+        init_latents = self.vae.config.scaling_factor * init_latents
 
         if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
             # expand init_latents for batch_size
@@ -153,7 +153,7 @@ class AudioDiffusionPipeline(DiffusionPipeline):
                 input_images = self.vqvae.encode(torch.unsqueeze(input_images, 0)).latent_dist.sample(
                     generator=generator
                 )[0]
-                input_images = 0.18215 * input_images
+                input_images = self.vqvae.config.scaling_factor * input_images
 
                 if start_step > 0:
                     images[0, 0] = self.scheduler.add_noise(input_images, noise, self.scheduler.timesteps[start_step - 1])
@@ -195,7 +195,7 @@ class AudioDiffusionPipeline(DiffusionPipeline):
 
         if self.vqvae is not None:
             # 0.18215 was the scaling factor used in training to ensure unit variance
-            images = 1 / 0.18215 * images
+            images = 1 / self.vqvae.config.scaling_factor * images
             images = self.vqvae.decode(images)["sample"]
 
         images = (images / 2 + 0.5).clamp(0, 1)
@@ -182,7 +182,7 @@ class DiTPipeline(DiffusionPipeline):
         else:
             latents = latent_model_input
 
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         samples = self.vae.decode(latents).sample
 
         samples = (samples / 2 + 0.5).clamp(0, 1)
@@ -182,7 +182,7 @@ class LDMTextToImagePipeline(DiffusionPipeline):
             latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample
 
         # scale and decode the image latents with vae
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vqvae.config.scaling_factor * latents
         image = self.vqvae.decode(latents).sample
 
         image = (image / 2 + 0.5).clamp(0, 1)
@@ -257,7 +257,7 @@ class PaintByExamplePipeline(DiffusionPipeline):
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
@@ -328,7 +328,7 @@ class PaintByExamplePipeline(DiffusionPipeline):
             masked_image_latents = torch.cat(masked_image_latents, dim=0)
         else:
             masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
-            masked_image_latents = 0.18215 * masked_image_latents
+            masked_image_latents = self.vae.config.scaling_factor * masked_image_latents
 
         # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
         if mask.shape[0] < batch_size:
@@ -474,7 +474,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
@@ -509,7 +509,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
         else:
             init_latents = self.vae.encode(image).latent_dist.sample(generator)
 
-        init_latents = 0.18215 * init_latents
+        init_latents = self.vae.config.scaling_factor * init_latents
 
         if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
             # expand init_latents for batch_size
@@ -267,7 +267,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
         latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state))
 
         # scale and decode the image latents with vae
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample
 
         image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
@@ -224,7 +224,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
         # Create init_latents
         init_latent_dist = self.vae.apply({"params": params["vae"]}, image, method=self.vae.encode).latent_dist
         init_latents = init_latent_dist.sample(key=prng_seed).transpose((0, 3, 1, 2))
-        init_latents = 0.18215 * init_latents
+        init_latents = self.vae.config.scaling_factor * init_latents
 
         def loop_body(step, args):
             latents, scheduler_state = args
@@ -272,7 +272,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
         latents, _ = jax.lax.fori_loop(start_timestep, num_inference_steps, loop_body, (latents, scheduler_state))
 
         # scale and decode the image latents with vae
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample
 
         image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
@@ -259,7 +259,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
             {"params": params["vae"]}, masked_image, method=self.vae.encode
         ).latent_dist
         masked_image_latents = masked_image_latent_dist.sample(key=mask_prng_seed).transpose((0, 3, 1, 2))
-        masked_image_latents = 0.18215 * masked_image_latents
+        masked_image_latents = self.vae.config.scaling_factor * masked_image_latents
         del mask_prng_seed
 
         mask = jax.image.resize(mask, (*mask.shape[:-2], *masked_image_latents.shape[-2:]), method="nearest")
@@ -327,7 +327,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
         )
 
         # scale and decode the image latents with vae
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample
 
         image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
@@ -366,7 +366,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
         return image, has_nsfw_concept
 
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
@@ -310,7 +310,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
@@ -413,7 +413,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
         else:
             init_latents = self.vae.encode(image).latent_dist.sample(generator)
 
-        init_latents = 0.18215 * init_latents
+        init_latents = self.vae.config.scaling_factor * init_latents
 
         if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
             # expand init_latents for batch_size
@@ -195,7 +195,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
@@ -400,7 +400,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
@@ -500,7 +500,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         else:
             init_latents = self.vae.encode(image).latent_dist.sample(generator)
 
-        init_latents = 0.18215 * init_latents
+        init_latents = self.vae.config.scaling_factor * init_latents
 
         if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
             # expand init_latents for batch_size
@@ -466,7 +466,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
@@ -561,7 +561,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
             masked_image_latents = torch.cat(masked_image_latents, dim=0)
         else:
             masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
-            masked_image_latents = 0.18215 * masked_image_latents
+            masked_image_latents = self.vae.config.scaling_factor * masked_image_latents
 
         # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
         if mask.shape[0] < batch_size:
@@ -367,7 +367,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
@@ -450,7 +450,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
         image = image.to(device=self.device, dtype=dtype)
         init_latent_dist = self.vae.encode(image).latent_dist
         init_latents = init_latent_dist.sample(generator=generator)
-        init_latents = 0.18215 * init_latents
+        init_latents = self.vae.config.scaling_factor * init_latents
 
         # Expand init_latents for batch_size and num_images_per_prompt
         init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
@@ -588,7 +588,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
@@ -313,7 +313,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
@@ -23,7 +23,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
 
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers
-from ...utils import is_accelerate_available, logging, randn_tensor
+from ...utils import deprecate, is_accelerate_available, logging, randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -89,6 +89,22 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
     ):
         super().__init__()
 
+        # check if vae has a config attribute `scaling_factor` and if it is set to 0.08333, else set it to 0.08333 and deprecate
+        is_vae_scaling_factor_set_to_0_08333 = (
+            hasattr(vae.config, "scaling_factor") and vae.config.scaling_factor == 0.08333
+        )
+        if not is_vae_scaling_factor_set_to_0_08333:
+            deprecation_message = (
+                "The configuration file of the vae does not contain `scaling_factor` or it is set to"
+                f" {vae.config.scaling_factor}, which seems highly unlikely. If your checkpoint is a fine-tuned"
+                " version of `stabilityai/stable-diffusion-x4-upscaler` you should change 'scaling_factor' to 0.08333"
+                " Please make sure to update the config accordingly, as not doing so might lead to incorrect results"
+                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be"
+                " very nice if you could open a Pull Request for the `vae/config.json` file"
+            )
+            deprecate("wrong scaling_factor", "1.0.0", deprecation_message, standard_warn=False)
+            vae.register_to_config(scaling_factor=0.08333)
+
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
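The upscaler needs this guard because its VAE was trained with a different scale (0.08333 rather than 0.18215), and `vae/config.json` files published before this change carry no `scaling_factor` at all; instead of failing, the pipeline patches the loaded config and warns. A reduced sketch of the pattern, with a freshly constructed VAE standing in for the loaded one:

```python
from diffusers import AutoencoderKL
from diffusers.utils import deprecate

vae = AutoencoderKL()  # stand-in; in the pipeline this is the upscaler's loaded VAE

expected = 0.08333
if not (hasattr(vae.config, "scaling_factor") and vae.config.scaling_factor == expected):
    # Same recovery as the pipeline: warn once, then patch the config in place.
    deprecate("wrong scaling_factor", "1.0.0", "patching vae scaling_factor to 0.08333", standard_warn=False)
    vae.register_to_config(scaling_factor=expected)

print(vae.config.scaling_factor)  # 0.08333
```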
@@ -292,9 +308,9 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
             extra_step_kwargs["generator"] = generator
         return extra_step_kwargs
 
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents with 0.18215->0.08333
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
-        latents = 1 / 0.08333 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
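Dropping the `with 0.18215->0.08333` suffix from the `# Copied from` marker is the quiet payoff: once the factor lives in config, the upscaler can share `decode_latents` verbatim with the base pipeline, and the repo's copy-checking tooling keeps the two in sync. A sketch of the now-shared decode path at the upscaler's scale (the directly constructed VAE is illustrative):

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL(scaling_factor=0.08333)  # upscaler-style scale, illustrative model
latents = torch.randn(1, 4, 32, 32)

with torch.no_grad():
    # Identical code now serves both 0.18215- and 0.08333-scaled checkpoints.
    image = vae.decode(1 / vae.config.scaling_factor * latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
```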
@@ -364,7 +364,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
@@ -330,7 +330,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
@@ -190,7 +190,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
     def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
@@ -247,7 +247,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
-        latents = 1 / 0.18215 * latents
+        latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16