
consistency decoder (#5694)

* consistency decoder

* rename

* Apply suggestions from code review

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py

* uP

* Apply suggestions from code review

* uP

* uP

* uP

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Will Berman
2023-11-09 03:21:41 -08:00
committed by GitHub
parent 43346adc1f
commit 2fd46405cd
24 changed files with 2038 additions and 38 deletions


@@ -200,6 +200,8 @@
title: AsymmetricAutoencoderKL
- local: api/models/autoencoder_tiny
title: Tiny AutoEncoder
- local: api/models/consistency_decoder_vae
title: ConsistencyDecoderVAE
- local: api/models/transformer2d
title: Transformer2D
- local: api/models/transformer_temporal
@@ -344,6 +346,8 @@
title: Overview
- local: api/schedulers/cm_stochastic_iterative
title: CMStochasticIterativeScheduler
- local: api/schedulers/consistency_decoder
title: ConsistencyDecoderScheduler
- local: api/schedulers/ddim_inverse
title: DDIMInverseScheduler
- local: api/schedulers/ddim


@@ -0,0 +1,18 @@
# Consistency Decoder
The consistency decoder can be used to decode the latents from the denoising UNet in the [`StableDiffusionPipeline`]. This decoder was introduced in the [DALL-E 3 technical report](https://openai.com/dall-e-3).
The original codebase can be found at [openai/consistencydecoder](https://github.com/openai/consistencydecoder).
<Tip warning={true}>
Inference is currently only supported for 2 iterations.
</Tip>
The pipeline could not have been contributed without the help of [madebyollin](https://github.com/madebyollin) and [mrsteyk](https://github.com/mrsteyk) from [this issue](https://github.com/openai/consistencydecoder/issues/1).
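A minimal usage sketch, assuming the `openai/consistency-decoder` checkpoint referenced elsewhere in this commit:

```py
import torch
from diffusers import ConsistencyDecoderVAE, StableDiffusionPipeline

# Swap the consistency decoder in for the default Stable Diffusion VAE.
vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", vae=vae, torch_dtype=torch.float16
).to("cuda")

image = pipe("horse", generator=torch.manual_seed(0)).images[0]
image.save("horse.png")
```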
## ConsistencyDecoderVAE
[[autodoc]] ConsistencyDecoderVAE
- all
- decode


@@ -0,0 +1,9 @@
# ConsistencyDecoderScheduler
This scheduler is used by the [`ConsistencyDecoderVAE`] and was introduced in the [DALL-E 3 technical report](https://openai.com/dall-e-3).
The original codebase can be found at [openai/consistency_models](https://github.com/openai/consistency_models).
## ConsistencyDecoderScheduler
[[autodoc]] schedulers.scheduling_consistency_decoder.ConsistencyDecoderScheduler
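[`ConsistencyDecoderVAE.decode`] drives a two-step consistency sampling loop with this scheduler. Below is a minimal sketch of that loop, using a zero-output stand-in for the decoder UNet (the real model consumes the scaled sample concatenated with the upsampled latents):

```py
import torch
from diffusers.schedulers import ConsistencyDecoderScheduler

scheduler = ConsistencyDecoderScheduler()
scheduler.set_timesteps(2)  # only two inference steps are supported

def fake_decoder_unet(model_input, t):
    # Stand-in for the pretrained decoder UNet; returns a tensor shaped like the image channels.
    return torch.zeros_like(model_input[:, :3])

latents = torch.randn(1, 4, 256, 256)  # placeholder conditioning (upsampled VAE latents)
x_t = scheduler.init_noise_sigma * torch.randn(1, 3, 256, 256)
for t in scheduler.timesteps:
    model_input = torch.cat([scheduler.scale_model_input(x_t, t), latents], dim=1)
    model_output = fake_decoder_unet(model_input, t)
    x_t = scheduler.step(model_output, t, x_t).prev_sample

image = x_t  # final consistency sample in [-1, 1]
```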

File diff suppressed because it is too large.


@@ -77,6 +77,7 @@ else:
"AsymmetricAutoencoderKL",
"AutoencoderKL",
"AutoencoderTiny",
"ConsistencyDecoderVAE",
"ControlNetModel",
"ModelMixin",
"MotionAdapter",
@@ -443,6 +444,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AsymmetricAutoencoderKL,
AutoencoderKL,
AutoencoderTiny,
ConsistencyDecoderVAE,
ControlNetModel,
ModelMixin,
MotionAdapter,


@@ -24,6 +24,7 @@ if is_torch_available():
_import_structure["autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
_import_structure["autoencoder_kl"] = ["AutoencoderKL"]
_import_structure["autoencoder_tiny"] = ["AutoencoderTiny"]
_import_structure["consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["controlnet"] = ["ControlNetModel"]
_import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
_import_structure["modeling_utils"] = ["ModelMixin"]
@@ -50,6 +51,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
from .autoencoder_kl import AutoencoderKL
from .autoencoder_tiny import AutoencoderTiny
from .consistency_decoder_vae import ConsistencyDecoderVAE
from .controlnet import ControlNetModel
from .dual_transformer_2d import DualTransformer2DModel
from .modeling_utils import ModelMixin


@@ -294,7 +294,9 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
return DecoderOutput(sample=dec)
@apply_forward_hook
def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
def decode(
self, z: torch.FloatTensor, return_dict: bool = True, generator=None
) -> Union[DecoderOutput, torch.FloatTensor]:
"""
Decode a batch of images.


@@ -0,0 +1,365 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..schedulers import ConsistencyDecoderScheduler
from ..utils import BaseOutput
from ..utils.accelerate_utils import apply_forward_hook
from ..utils.torch_utils import randn_tensor
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
from .modeling_utils import ModelMixin
from .unet_2d import UNet2DModel
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
@dataclass
class ConsistencyDecoderVAEOutput(BaseOutput):
"""
Output of encoding method.
Args:
latent_dist (`DiagonalGaussianDistribution`):
Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
`DiagonalGaussianDistribution` allows for sampling latents from the distribution.
"""
latent_dist: "DiagonalGaussianDistribution"
class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
r"""
The consistency decoder used with DALL-E 3.
Examples:
```py
>>> import torch
>>> from diffusers import StableDiffusionPipeline, ConsistencyDecoderVAE
>>> vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
>>> pipe = StableDiffusionPipeline.from_pretrained(
... "runwayml/stable-diffusion-v1-5", vae=vae, torch_dtype=torch.float16
... ).to("cuda")
>>> pipe("horse", generator=torch.manual_seed(0)).images
```
"""
@register_to_config
def __init__(self, encoder_args, decoder_args, scaling_factor, block_out_channels, latent_channels):
super().__init__()
self.encoder = Encoder(**encoder_args)
self.decoder_unet = UNet2DModel(**decoder_args)
self.decoder_scheduler = ConsistencyDecoderScheduler()
self.register_buffer(
"means",
torch.tensor([0.38862467, 0.02253063, 0.07381133, -0.0171294])[None, :, None, None],
persistent=False,
)
self.register_buffer(
"stds", torch.tensor([0.9654121, 1.0440036, 0.76147926, 0.77022034])[None, :, None, None], persistent=False
)
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
self.use_slicing = False
self.use_tiling = False
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.enable_tiling
def enable_tiling(self, use_tiling: bool = True):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.use_tiling = use_tiling
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.disable_tiling
def disable_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.enable_tiling(False)
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.enable_slicing
def enable_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.disable_slicing
def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
@property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model,
indexed by their weight names.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor, _remove_lora=_remove_lora)
else:
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnAddedKVProcessor()
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnProcessor()
else:
raise ValueError(
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor, _remove_lora=True)
@apply_forward_hook
def encode(
self, x: torch.FloatTensor, return_dict: bool = True
) -> Union[ConsistencyDecoderVAEOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of images into latents.
Args:
x (`torch.FloatTensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] instead of a plain
tuple.
Returns:
The latent representations of the encoded images. If `return_dict` is True, a
[`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] is returned, otherwise a plain `tuple`
is returned.
"""
if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
return self.tiled_encode(x, return_dict=return_dict)
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
if not return_dict:
return (posterior,)
return ConsistencyDecoderVAEOutput(latent_dist=posterior)
@apply_forward_hook
def decode(
self,
z: torch.FloatTensor,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
num_inference_steps=2,
) -> Union[DecoderOutput, torch.FloatTensor]:
z = (z * self.config.scaling_factor - self.means) / self.stds
scale_factor = 2 ** (len(self.config.block_out_channels) - 1)
z = F.interpolate(z, mode="nearest", scale_factor=scale_factor)
batch_size, _, height, width = z.shape
self.decoder_scheduler.set_timesteps(num_inference_steps, device=self.device)
x_t = self.decoder_scheduler.init_noise_sigma * randn_tensor(
(batch_size, 3, height, width), generator=generator, dtype=z.dtype, device=z.device
)
for t in self.decoder_scheduler.timesteps:
model_input = torch.concat([self.decoder_scheduler.scale_model_input(x_t, t), z], dim=1)
model_output = self.decoder_unet(model_input, t).sample[:, :3, :, :]
prev_sample = self.decoder_scheduler.step(model_output, t, x_t, generator).prev_sample
x_t = prev_sample
x_0 = x_t
if not return_dict:
return (x_0,)
return DecoderOutput(sample=x_0)
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.blend_v
def blend_v(self, a, b, blend_extent):
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
for y in range(blend_extent):
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
return b
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.blend_h
def blend_h(self, a, b, blend_extent):
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
for x in range(blend_extent):
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
return b
def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> ConsistencyDecoderVAEOutput:
r"""Encode a batch of images using a tiled encoder.
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
output, but they should be much less noticeable.
Args:
x (`torch.FloatTensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] instead of a
plain tuple.
Returns:
[`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] or `tuple`:
If return_dict is True, a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] is returned,
otherwise a plain `tuple` is returned.
"""
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
row_limit = self.tile_latent_min_size - blend_extent
# Split the image into 512x512 tiles and encode them separately.
rows = []
for i in range(0, x.shape[2], overlap_size):
row = []
for j in range(0, x.shape[3], overlap_size):
tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
tile = self.encoder(tile)
tile = self.quant_conv(tile)
row.append(tile)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=3))
moments = torch.cat(result_rows, dim=2)
posterior = DiagonalGaussianDistribution(moments)
if not return_dict:
return (posterior,)
return ConsistencyDecoderVAEOutput(latent_dist=posterior)
def forward(
self,
sample: torch.FloatTensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.FloatTensor]:
r"""
Args:
sample (`torch.FloatTensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z, generator=generator).sample
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
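For reference, a minimal sketch of a direct encode/decode round trip with the model above, along the lines of the integration tests added later in this commit (the `openai/consistency-decoder` checkpoint id is marked as provisional there):

```py
import numpy as np
import torch

from diffusers import ConsistencyDecoderVAE
from diffusers.utils.loading_utils import load_image

vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder").to("cuda")

image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
    "/img2img/sketch-mountains-input.jpg"
).resize((256, 256))
# Scale pixels to [-1, 1], move channels first, add a batch dimension.
sample = torch.from_numpy(np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1)[None].cuda()

latent = vae.encode(sample).latent_dist.mean
reconstruction = vae.decode(latent, generator=torch.Generator("cpu").manual_seed(0)).sample
```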


@@ -117,6 +117,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
add_attention: bool = True,
class_embed_type: Optional[str] = None,
num_class_embeds: Optional[int] = None,
num_train_timesteps: Optional[int] = None,
):
super().__init__()
@@ -144,6 +145,9 @@ class UNet2DModel(ModelMixin, ConfigMixin):
elif time_embedding_type == "positional":
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
timestep_input_dim = block_out_channels[0]
elif time_embedding_type == "learned":
self.time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0])
timestep_input_dim = block_out_channels[0]
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)


@@ -852,7 +852,9 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
callback(step_idx, t, latents)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
0
]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents


@@ -893,7 +893,9 @@ class AltDiffusionImg2ImgPipeline(
callback(step_idx, t, latents)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
0
]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents


@@ -6,7 +6,9 @@ from ...utils import (
)
_import_structure = {"pipeline_consistency_models": ["ConsistencyModelPipeline"]}
_import_structure = {
"pipeline_consistency_models": ["ConsistencyModelPipeline"],
}
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pipeline_consistency_models import ConsistencyModelPipeline


@@ -1,3 +1,17 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, List, Optional, Union
import torch


@@ -1058,7 +1058,9 @@ class StableDiffusionControlNetPipeline(
torch.cuda.empty_cache()
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
0
]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents


@@ -1138,7 +1138,9 @@ class StableDiffusionControlNetImg2ImgPipeline(
torch.cuda.empty_cache()
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
0
]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents


@@ -1405,7 +1405,9 @@ class StableDiffusionControlNetInpaintPipeline(
torch.cuda.empty_cache()
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
0
]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents


@@ -838,7 +838,9 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
callback(step_idx, t, latents)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
0
]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents


@@ -885,7 +885,9 @@ class StableDiffusionImg2ImgPipeline(
callback(step_idx, t, latents)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
0
]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents


@@ -1159,7 +1159,9 @@ class StableDiffusionInpaintPipeline(
init_image = self._encode_vae_image(init_image, generator=generator)
mask_condition = mask_condition.to(device=device, dtype=masked_image_latents.dtype)
condition_kwargs = {"image": init_image_condition, "mask": mask_condition}
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, **condition_kwargs)[0]
image = self.vae.decode(
latents / self.vae.config.scaling_factor, return_dict=False, generator=generator, **condition_kwargs
)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents


@@ -38,6 +38,7 @@ except OptionalDependencyNotAvailable:
_dummy_modules.update(get_objects_from_module(dummy_pt_objects))
else:
_import_structure["scheduling_consistency_decoder"] = ["ConsistencyDecoderScheduler"]
_import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"]
_import_structure["scheduling_ddim"] = ["DDIMScheduler"]
_import_structure["scheduling_ddim_inverse"] = ["DDIMInverseScheduler"]
@@ -128,6 +129,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from ..utils.dummy_pt_objects import * # noqa F403
else:
from .scheduling_consistency_decoder import ConsistencyDecoderScheduler
from .scheduling_consistency_models import CMStochasticIterativeScheduler
from .scheduling_ddim import DDIMScheduler
from .scheduling_ddim_inverse import DDIMInverseScheduler


@@ -0,0 +1,180 @@
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from ..utils.torch_utils import randn_tensor
from .scheduling_utils import SchedulerMixin
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
num_diffusion_timesteps,
max_beta=0.999,
alpha_transform_type="cosine",
):
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
Args:
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
Choose from `cosine` or `exp`
Returns:
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
"""
if alpha_transform_type == "cosine":
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
return math.exp(t * -12.0)
else:
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)
@dataclass
class ConsistencyDecoderSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's `step` function.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
"""
prev_sample: torch.FloatTensor
class ConsistencyDecoderScheduler(SchedulerMixin, ConfigMixin):
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1024,
sigma_data: float = 0.5,
):
betas = betas_for_alpha_bar(num_train_timesteps)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
sigmas = torch.sqrt(1.0 / alphas_cumprod - 1)
sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)
self.c_skip = sqrt_recip_alphas_cumprod * sigma_data**2 / (sigmas**2 + sigma_data**2)
self.c_out = sigmas * sigma_data / (sigmas**2 + sigma_data**2) ** 0.5
self.c_in = sqrt_recip_alphas_cumprod / (sigmas**2 + sigma_data**2) ** 0.5
def set_timesteps(
self,
num_inference_steps: Optional[int] = None,
device: Union[str, torch.device] = None,
):
if num_inference_steps != 2:
raise ValueError("Currently more than 2 inference steps are not supported.")
self.timesteps = torch.tensor([1008, 512], dtype=torch.long, device=device)
self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(device)
self.sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod.to(device)
self.c_skip = self.c_skip.to(device)
self.c_out = self.c_out.to(device)
self.c_in = self.c_in.to(device)
@property
def init_noise_sigma(self):
return self.sqrt_one_minus_alphas_cumprod[self.timesteps[0]]
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.FloatTensor`):
The input sample.
timestep (`int`, *optional*):
The current timestep in the diffusion chain.
Returns:
`torch.FloatTensor`:
A scaled input sample.
"""
return sample * self.c_in[timestep]
def step(
self,
model_output: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
sample: torch.FloatTensor,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[ConsistencyDecoderSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`):
The direct output from the learned diffusion model.
timestep (`float`):
The current timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a
[`~schedulers.scheduling_consistency_decoder.ConsistencyDecoderSchedulerOutput`] or `tuple`.
Returns:
[`~schedulers.scheduling_consistency_decoder.ConsistencyDecoderSchedulerOutput`] or `tuple`:
If return_dict is `True`,
[`~schedulers.scheduling_consistency_decoder.ConsistencyDecoderSchedulerOutput`] is returned, otherwise
a tuple is returned where the first element is the sample tensor.
"""
x_0 = self.c_out[timestep] * model_output + self.c_skip[timestep] * sample
timestep_idx = torch.where(self.timesteps == timestep)[0]
if timestep_idx == len(self.timesteps) - 1:
prev_sample = x_0
else:
noise = randn_tensor(x_0.shape, generator=generator, dtype=x_0.dtype, device=x_0.device)
prev_sample = (
self.sqrt_alphas_cumprod[self.timesteps[timestep_idx + 1]].to(x_0.dtype) * x_0
+ self.sqrt_one_minus_alphas_cumprod[self.timesteps[timestep_idx + 1]].to(x_0.dtype) * noise
)
if not return_dict:
return (prev_sample,)
return ConsistencyDecoderSchedulerOutput(prev_sample=prev_sample)


@@ -47,6 +47,21 @@ class AutoencoderTiny(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class ConsistencyDecoderVAE(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class ControlNetModel(metaclass=DummyObject):
_backends = ["torch"]


@@ -196,11 +196,15 @@ class UNetTesterMixin:
class ModelTesterMixin:
main_input_name = None # overwrite in model specific tester class
base_precision = 1e-3
forward_requires_fresh_args = False
def test_from_save_pretrained(self, expected_max_diff=5e-5):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
if self.forward_requires_fresh_args:
model = self.model_class(**self.init_dict)
else:
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model = self.model_class(**init_dict)
if hasattr(model, "set_default_attn_processor"):
model.set_default_attn_processor()
model.to(torch_device)
@@ -214,11 +218,18 @@ class ModelTesterMixin:
new_model.to(torch_device)
with torch.no_grad():
image = model(**inputs_dict)
if self.forward_requires_fresh_args:
image = model(**self.inputs_dict(0))
else:
image = model(**inputs_dict)
if isinstance(image, dict):
image = image.to_tuple()[0]
new_image = new_model(**inputs_dict)
if self.forward_requires_fresh_args:
new_image = new_model(**self.inputs_dict(0))
else:
new_image = new_model(**inputs_dict)
if isinstance(new_image, dict):
new_image = new_image.to_tuple()[0]
@@ -275,8 +286,11 @@ class ModelTesterMixin:
)
def test_set_xformers_attn_processor_for_determinism(self):
torch.use_deterministic_algorithms(False)
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
if self.forward_requires_fresh_args:
model = self.model_class(**self.init_dict)
else:
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
if not hasattr(model, "set_attn_processor"):
@@ -286,17 +300,26 @@ class ModelTesterMixin:
model.set_default_attn_processor()
assert all(type(proc) == AttnProcessor for proc in model.attn_processors.values())
with torch.no_grad():
output = model(**inputs_dict)[0]
if self.forward_requires_fresh_args:
output = model(**self.inputs_dict(0))[0]
else:
output = model(**inputs_dict)[0]
model.enable_xformers_memory_efficient_attention()
assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values())
with torch.no_grad():
output_2 = model(**inputs_dict)[0]
if self.forward_requires_fresh_args:
output_2 = model(**self.inputs_dict(0))[0]
else:
output_2 = model(**inputs_dict)[0]
model.set_attn_processor(XFormersAttnProcessor())
assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values())
with torch.no_grad():
output_3 = model(**inputs_dict)[0]
if self.forward_requires_fresh_args:
output_3 = model(**self.inputs_dict(0))[0]
else:
output_3 = model(**inputs_dict)[0]
torch.use_deterministic_algorithms(True)
@@ -307,8 +330,12 @@ class ModelTesterMixin:
@require_torch_gpu
def test_set_attn_processor_for_determinism(self):
torch.use_deterministic_algorithms(False)
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
if self.forward_requires_fresh_args:
model = self.model_class(**self.init_dict)
else:
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
if not hasattr(model, "set_attn_processor"):
@@ -317,22 +344,34 @@ class ModelTesterMixin:
assert all(type(proc) == AttnProcessor2_0 for proc in model.attn_processors.values())
with torch.no_grad():
output_1 = model(**inputs_dict)[0]
if self.forward_requires_fresh_args:
output_1 = model(**self.inputs_dict(0))[0]
else:
output_1 = model(**inputs_dict)[0]
model.set_default_attn_processor()
assert all(type(proc) == AttnProcessor for proc in model.attn_processors.values())
with torch.no_grad():
output_2 = model(**inputs_dict)[0]
if self.forward_requires_fresh_args:
output_2 = model(**self.inputs_dict(0))[0]
else:
output_2 = model(**inputs_dict)[0]
model.set_attn_processor(AttnProcessor2_0())
assert all(type(proc) == AttnProcessor2_0 for proc in model.attn_processors.values())
with torch.no_grad():
output_4 = model(**inputs_dict)[0]
if self.forward_requires_fresh_args:
output_4 = model(**self.inputs_dict(0))[0]
else:
output_4 = model(**inputs_dict)[0]
model.set_attn_processor(AttnProcessor())
assert all(type(proc) == AttnProcessor for proc in model.attn_processors.values())
with torch.no_grad():
output_5 = model(**inputs_dict)[0]
if self.forward_requires_fresh_args:
output_5 = model(**self.inputs_dict(0))[0]
else:
output_5 = model(**inputs_dict)[0]
torch.use_deterministic_algorithms(True)
@@ -342,9 +381,12 @@ class ModelTesterMixin:
assert torch.allclose(output_2, output_5, atol=self.base_precision)
def test_from_save_pretrained_variant(self, expected_max_diff=5e-5):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
if self.forward_requires_fresh_args:
model = self.model_class(**self.init_dict)
else:
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model = self.model_class(**init_dict)
if hasattr(model, "set_default_attn_processor"):
model.set_default_attn_processor()
@@ -367,11 +409,17 @@ class ModelTesterMixin:
new_model.to(torch_device)
with torch.no_grad():
image = model(**inputs_dict)
if self.forward_requires_fresh_args:
image = model(**self.inputs_dict(0))
else:
image = model(**inputs_dict)
if isinstance(image, dict):
image = image.to_tuple()[0]
new_image = new_model(**inputs_dict)
if self.forward_requires_fresh_args:
new_image = new_model(**self.inputs_dict(0))
else:
new_image = new_model(**inputs_dict)
if isinstance(new_image, dict):
new_image = new_image.to_tuple()[0]
@@ -405,17 +453,26 @@ class ModelTesterMixin:
assert new_model.dtype == dtype
def test_determinism(self, expected_max_diff=1e-5):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
if self.forward_requires_fresh_args:
model = self.model_class(**self.init_dict)
else:
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
first = model(**inputs_dict)
if self.forward_requires_fresh_args:
first = model(**self.inputs_dict(0))
else:
first = model(**inputs_dict)
if isinstance(first, dict):
first = first.to_tuple()[0]
second = model(**inputs_dict)
if self.forward_requires_fresh_args:
second = model(**self.inputs_dict(0))
else:
second = model(**inputs_dict)
if isinstance(second, dict):
second = second.to_tuple()[0]
@@ -548,15 +605,22 @@ class ModelTesterMixin:
),
)
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
if self.forward_requires_fresh_args:
model = self.model_class(**self.init_dict)
else:
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs_dict = model(**inputs_dict)
outputs_tuple = model(**inputs_dict, return_dict=False)
if self.forward_requires_fresh_args:
outputs_dict = model(**self.inputs_dict(0))
outputs_tuple = model(**self.inputs_dict(0), return_dict=False)
else:
outputs_dict = model(**inputs_dict)
outputs_tuple = model(**inputs_dict, return_dict=False)
recursive_check(outputs_tuple, outputs_dict)


@@ -16,11 +16,19 @@
import gc
import unittest
import numpy as np
import torch
from parameterized import parameterized
from diffusers import AsymmetricAutoencoderKL, AutoencoderKL, AutoencoderTiny
from diffusers import (
AsymmetricAutoencoderKL,
AutoencoderKL,
AutoencoderTiny,
ConsistencyDecoderVAE,
StableDiffusionPipeline,
)
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.loading_utils import load_image
from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
@@ -30,6 +38,7 @@ from diffusers.utils.testing_utils import (
torch_all_close,
torch_device,
)
from diffusers.utils.torch_utils import randn_tensor
from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
@@ -269,6 +278,79 @@ class AutoencoderTinyTests(ModelTesterMixin, unittest.TestCase):
pass
class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
model_class = ConsistencyDecoderVAE
main_input_name = "sample"
base_precision = 1e-2
forward_requires_fresh_args = True
def inputs_dict(self, seed=None):
generator = torch.Generator("cpu")
if seed is not None:
generator.manual_seed(seed)
image = randn_tensor((4, 3, 32, 32), generator=generator, device=torch.device(torch_device))
return {"sample": image, "generator": generator}
@property
def input_shape(self):
return (3, 32, 32)
@property
def output_shape(self):
return (3, 32, 32)
@property
def init_dict(self):
return {
"encoder_args": {
"block_out_channels": [32, 64],
"in_channels": 3,
"out_channels": 4,
"down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
},
"decoder_args": {
"act_fn": "silu",
"add_attention": False,
"block_out_channels": [32, 64],
"down_block_types": [
"ResnetDownsampleBlock2D",
"ResnetDownsampleBlock2D",
],
"downsample_padding": 1,
"downsample_type": "conv",
"dropout": 0.0,
"in_channels": 7,
"layers_per_block": 1,
"norm_eps": 1e-05,
"norm_num_groups": 32,
"num_train_timesteps": 1024,
"out_channels": 6,
"resnet_time_scale_shift": "scale_shift",
"time_embedding_type": "learned",
"up_block_types": [
"ResnetUpsampleBlock2D",
"ResnetUpsampleBlock2D",
],
"upsample_type": "conv",
},
"scaling_factor": 1,
"block_out_channels": [32, 64],
"latent_channels": 4,
}
def prepare_init_args_and_inputs_for_common(self):
return self.init_dict, self.inputs_dict()
@unittest.skip
def test_training(self):
...
@unittest.skip
def test_ema_training(self):
...
@slow
class AutoencoderTinyIntegrationTests(unittest.TestCase):
def tearDown(self):
@@ -721,3 +803,94 @@ class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
tolerance = 3e-3 if torch_device != "mps" else 1e-2
assert torch_all_close(output_slice, expected_output_slice, atol=tolerance)
@slow
class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_encode_decode(self):
vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder") # TODO - update
vae.to(torch_device)
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/img2img/sketch-mountains-input.jpg"
).resize((256, 256))
image = torch.from_numpy(np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1)[
None, :, :, :
].cuda()
latent = vae.encode(image).latent_dist.mean
sample = vae.decode(latent, generator=torch.Generator("cpu").manual_seed(0)).sample
actual_output = sample[0, :2, :2, :2].flatten().cpu()
expected_output = torch.tensor([-0.0141, -0.0014, 0.0115, 0.0086, 0.1051, 0.1053, 0.1031, 0.1024])
assert torch_all_close(actual_output, expected_output, atol=5e-3)
def test_sd(self):
vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder") # TODO - update
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", vae=vae, safety_checker=None)
pipe.to(torch_device)
out = pipe(
"horse", num_inference_steps=2, output_type="pt", generator=torch.Generator("cpu").manual_seed(0)
).images[0]
actual_output = out[:2, :2, :2].flatten().cpu()
expected_output = torch.tensor([0.7686, 0.8228, 0.6489, 0.7455, 0.8661, 0.8797, 0.8241, 0.8759])
assert torch_all_close(actual_output, expected_output, atol=5e-3)
def test_encode_decode_f16(self):
vae = ConsistencyDecoderVAE.from_pretrained(
"openai/consistency-decoder", torch_dtype=torch.float16
) # TODO - update
vae.to(torch_device)
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/img2img/sketch-mountains-input.jpg"
).resize((256, 256))
image = (
torch.from_numpy(np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1)[None, :, :, :]
.half()
.cuda()
)
latent = vae.encode(image).latent_dist.mean
sample = vae.decode(latent, generator=torch.Generator("cpu").manual_seed(0)).sample
actual_output = sample[0, :2, :2, :2].flatten().cpu()
expected_output = torch.tensor(
[-0.0111, -0.0125, -0.0017, -0.0007, 0.1257, 0.1465, 0.1450, 0.1471], dtype=torch.float16
)
assert torch_all_close(actual_output, expected_output, atol=5e-3)
def test_sd_f16(self):
vae = ConsistencyDecoderVAE.from_pretrained(
"openai/consistency-decoder", torch_dtype=torch.float16
) # TODO - update
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, vae=vae, safety_checker=None
)
pipe.to(torch_device)
out = pipe(
"horse", num_inference_steps=2, output_type="pt", generator=torch.Generator("cpu").manual_seed(0)
).images[0]
actual_output = out[:2, :2, :2].flatten().cpu()
expected_output = torch.tensor(
[0.0000, 0.0249, 0.0000, 0.0000, 0.1709, 0.2773, 0.0471, 0.1035], dtype=torch.float16
)
assert torch_all_close(actual_output, expected_output, atol=5e-3)