consistency decoder (#5694)
* consistency decoder
* rename
* Apply suggestions from code review
  Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* Update src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py
* uP
* Apply suggestions from code review
* uP
* uP
* uP

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
@@ -200,6 +200,8 @@
title: AsymmetricAutoencoderKL
- local: api/models/autoencoder_tiny
title: Tiny AutoEncoder
- local: api/models/consistency_decoder_vae
title: ConsistencyDecoderVAE
- local: api/models/transformer2d
title: Transformer2D
- local: api/models/transformer_temporal
@@ -344,6 +346,8 @@
title: Overview
- local: api/schedulers/cm_stochastic_iterative
title: CMStochasticIterativeScheduler
- local: api/schedulers/consistency_decoder
title: ConsistencyDecoderScheduler
- local: api/schedulers/ddim_inverse
title: DDIMInverseScheduler
- local: api/schedulers/ddim
docs/source/en/api/models/consistency_decoder_vae.md (new file, 18 lines)
@@ -0,0 +1,18 @@
# Consistency Decoder

The consistency decoder can be used to decode the latents from the denoising UNet in the [`StableDiffusionPipeline`]. This decoder was introduced in the [DALL-E 3 technical report](https://openai.com/dall-e-3).

The original codebase can be found at [openai/consistencydecoder](https://github.com/openai/consistencydecoder).

<Tip warning={true}>

Inference is only supported for 2 iterations as of now.

</Tip>

The pipeline could not have been contributed without the help of [madebyollin](https://github.com/madebyollin) and [mrsteyk](https://github.com/mrsteyk) from [this issue](https://github.com/openai/consistencydecoder/issues/1).

## ConsistencyDecoderVAE
[[autodoc]] ConsistencyDecoderVAE
- all
- decode
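A minimal usage sketch, assembled from the class docstring and the integration tests added later in this diff; the `openai/consistency-decoder` checkpoint id is the one those tests use (they flag it `TODO - update`), so it may change:

```py
import torch
from diffusers import ConsistencyDecoderVAE, StableDiffusionPipeline

# Load the consistency decoder and swap it in for the default Stable Diffusion VAE.
vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", vae=vae, torch_dtype=torch.float16
).to("cuda")

# The generator is now forwarded all the way to `vae.decode`, so the result is reproducible.
image = pipe("horse", generator=torch.manual_seed(0)).images[0]
```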
docs/source/en/api/schedulers/consistency_decoder.md (new file, 9 lines)
@@ -0,0 +1,9 @@
# ConsistencyDecoderScheduler

This scheduler is used by the [`ConsistencyDecoderVAE`] and was introduced in [DALL-E 3](https://openai.com/dall-e-3).

The original codebase can be found at [openai/consistency_models](https://github.com/openai/consistency_models).


## ConsistencyDecoderScheduler
[[autodoc]] schedulers.scheduling_consistency_decoder.ConsistencyDecoderScheduler
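To see how this scheduler is actually driven, here is a condensed, self-contained sketch of the two-step sampling loop that `ConsistencyDecoderVAE.decode` (added later in this diff) performs. The toy `UNet2DModel` configuration mirrors the small test config at the bottom of this diff, not the real checkpoint, and `z` stands in for the upsampled, normalized latents:

```py
import torch
from diffusers import UNet2DModel
from diffusers.schedulers import ConsistencyDecoderScheduler
from diffusers.utils.torch_utils import randn_tensor

# Tiny stand-in for the decoder UNet: 7 input channels (3 image + 4 latent),
# 6 output channels, learned time embedding over 1024 timesteps.
unet = UNet2DModel(
    in_channels=7, out_channels=6, block_out_channels=(32, 64), layers_per_block=1,
    down_block_types=("ResnetDownsampleBlock2D", "ResnetDownsampleBlock2D"),
    up_block_types=("ResnetUpsampleBlock2D", "ResnetUpsampleBlock2D"),
    add_attention=False, resnet_time_scale_shift="scale_shift",
    downsample_type="conv", upsample_type="conv",
    time_embedding_type="learned", num_train_timesteps=1024,
)
scheduler = ConsistencyDecoderScheduler()
scheduler.set_timesteps(2)  # only 2 steps are supported

generator = torch.manual_seed(0)
z = torch.randn(1, 4, 32, 32)  # stands in for the upsampled, normalized latents
x_t = scheduler.init_noise_sigma * randn_tensor((1, 3, 32, 32), generator=generator)

for t in scheduler.timesteps:  # tensor([1008, 512])
    model_input = torch.cat([scheduler.scale_model_input(x_t, t), z], dim=1)
    model_output = unet(model_input, t).sample[:, :3]
    x_t = scheduler.step(model_output, t, x_t, generator).prev_sample
image = x_t  # the final step returns the predicted clean sample directly
```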
scripts/convert_consistency_decoder.py (new file, 1128 lines; diff not shown because it is too large)
@@ -77,6 +77,7 @@ else:
|
||||
"AsymmetricAutoencoderKL",
|
||||
"AutoencoderKL",
|
||||
"AutoencoderTiny",
|
||||
"ConsistencyDecoderVAE",
|
||||
"ControlNetModel",
|
||||
"ModelMixin",
|
||||
"MotionAdapter",
|
||||
@@ -443,6 +444,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AsymmetricAutoencoderKL,
|
||||
AutoencoderKL,
|
||||
AutoencoderTiny,
|
||||
ConsistencyDecoderVAE,
|
||||
ControlNetModel,
|
||||
ModelMixin,
|
||||
MotionAdapter,
|
||||
|
||||
@@ -24,6 +24,7 @@ if is_torch_available():
|
||||
_import_structure["autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
|
||||
_import_structure["autoencoder_kl"] = ["AutoencoderKL"]
|
||||
_import_structure["autoencoder_tiny"] = ["AutoencoderTiny"]
|
||||
_import_structure["consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
|
||||
_import_structure["controlnet"] = ["ControlNetModel"]
|
||||
_import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
|
||||
_import_structure["modeling_utils"] = ["ModelMixin"]
|
||||
@@ -50,6 +51,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
|
||||
from .autoencoder_kl import AutoencoderKL
|
||||
from .autoencoder_tiny import AutoencoderTiny
|
||||
from .consistency_decoder_vae import ConsistencyDecoderVAE
|
||||
from .controlnet import ControlNetModel
|
||||
from .dual_transformer_2d import DualTransformer2DModel
|
||||
from .modeling_utils import ModelMixin
|
||||
|
||||
@@ -294,7 +294,9 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
|
||||
def decode(
|
||||
self, z: torch.FloatTensor, return_dict: bool = True, generator=None
|
||||
) -> Union[DecoderOutput, torch.FloatTensor]:
|
||||
"""
|
||||
Decode a batch of images.
|
||||
|
||||
|
||||
src/diffusers/models/consistency_decoder_vae.py (new file, 365 lines)
@@ -0,0 +1,365 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..schedulers import ConsistencyDecoderScheduler
|
||||
from ..utils import BaseOutput
|
||||
from ..utils.accelerate_utils import apply_forward_hook
|
||||
from ..utils.torch_utils import randn_tensor
|
||||
from .attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
AttentionProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
AttnProcessor,
|
||||
)
|
||||
from .modeling_utils import ModelMixin
|
||||
from .unet_2d import UNet2DModel
|
||||
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConsistencyDecoderVAEOutput(BaseOutput):
|
||||
"""
|
||||
Output of encoding method.
|
||||
|
||||
Args:
|
||||
latent_dist (`DiagonalGaussianDistribution`):
|
||||
Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
|
||||
`DiagonalGaussianDistribution` allows for sampling latents from the distribution.
|
||||
"""
|
||||
|
||||
latent_dist: "DiagonalGaussianDistribution"
|
||||
|
||||
|
||||
class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
The consistency decoder used with DALL-E 3.
|
||||
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import StableDiffusionPipeline, ConsistencyDecoderVAE
|
||||
|
||||
>>> vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
|
||||
>>> pipe = StableDiffusionPipeline.from_pretrained(
|
||||
... "runwayml/stable-diffusion-v1-5", vae=vae, torch_dtype=torch.float16
|
||||
... ).to("cuda")
|
||||
|
||||
>>> pipe("horse", generator=torch.manual_seed(0)).images
|
||||
```
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(self, encoder_args, decoder_args, scaling_factor, block_out_channels, latent_channels):
|
||||
super().__init__()
|
||||
self.encoder = Encoder(**encoder_args)
|
||||
self.decoder_unet = UNet2DModel(**decoder_args)
|
||||
self.decoder_scheduler = ConsistencyDecoderScheduler()
|
||||
self.register_buffer(
|
||||
"means",
|
||||
torch.tensor([0.38862467, 0.02253063, 0.07381133, -0.0171294])[None, :, None, None],
|
||||
persistent=False,
|
||||
)
|
||||
self.register_buffer(
|
||||
"stds", torch.tensor([0.9654121, 1.0440036, 0.76147926, 0.77022034])[None, :, None, None], persistent=False
|
||||
)
|
||||
|
||||
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
|
||||
|
||||
self.use_slicing = False
|
||||
self.use_tiling = False
|
||||
|
||||
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.enable_tiling
|
||||
def enable_tiling(self, use_tiling: bool = True):
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
self.use_tiling = use_tiling
|
||||
|
||||
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.disable_tiling
|
||||
def disable_tiling(self):
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.enable_tiling(False)
|
||||
|
||||
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.enable_slicing
|
||||
def enable_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.use_slicing = True
|
||||
|
||||
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.disable_slicing
|
||||
def disable_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_slicing = False
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(
|
||||
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
|
||||
):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor, _remove_lora=_remove_lora)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
"""
|
||||
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||
processor = AttnAddedKVProcessor()
|
||||
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||
processor = AttnProcessor()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
||||
)
|
||||
|
||||
self.set_attn_processor(processor, _remove_lora=True)
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.FloatTensor, return_dict: bool = True
|
||||
) -> Union[ConsistencyDecoderVAEOutput, Tuple[DiagonalGaussianDistribution]]:
|
||||
"""
|
||||
Encode a batch of images into latents.
|
||||
|
||||
Args:
|
||||
x (`torch.FloatTensor`): Input batch of images.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
The latent representations of the encoded images. If `return_dict` is True, a
|
||||
[`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] is returned, otherwise a plain `tuple`
|
||||
is returned.
|
||||
"""
|
||||
if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
|
||||
return self.tiled_encode(x, return_dict=return_dict)
|
||||
|
||||
if self.use_slicing and x.shape[0] > 1:
|
||||
encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
|
||||
h = torch.cat(encoded_slices)
|
||||
else:
|
||||
h = self.encoder(x)
|
||||
|
||||
moments = self.quant_conv(h)
|
||||
posterior = DiagonalGaussianDistribution(moments)
|
||||
|
||||
if not return_dict:
|
||||
return (posterior,)
|
||||
|
||||
return ConsistencyDecoderVAEOutput(latent_dist=posterior)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(
|
||||
self,
|
||||
z: torch.FloatTensor,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
return_dict: bool = True,
|
||||
num_inference_steps=2,
|
||||
) -> Union[DecoderOutput, torch.FloatTensor]:
|
||||
z = (z * self.config.scaling_factor - self.means) / self.stds
|
||||
|
||||
scale_factor = 2 ** (len(self.config.block_out_channels) - 1)
|
||||
z = F.interpolate(z, mode="nearest", scale_factor=scale_factor)
|
||||
|
||||
batch_size, _, height, width = z.shape
|
||||
|
||||
self.decoder_scheduler.set_timesteps(num_inference_steps, device=self.device)
|
||||
|
||||
x_t = self.decoder_scheduler.init_noise_sigma * randn_tensor(
|
||||
(batch_size, 3, height, width), generator=generator, dtype=z.dtype, device=z.device
|
||||
)
|
||||
|
||||
for t in self.decoder_scheduler.timesteps:
|
||||
model_input = torch.concat([self.decoder_scheduler.scale_model_input(x_t, t), z], dim=1)
|
||||
model_output = self.decoder_unet(model_input, t).sample[:, :3, :, :]
|
||||
prev_sample = self.decoder_scheduler.step(model_output, t, x_t, generator).prev_sample
|
||||
x_t = prev_sample
|
||||
|
||||
x_0 = x_t
|
||||
|
||||
if not return_dict:
|
||||
return (x_0,)
|
||||
|
||||
return DecoderOutput(sample=x_0)
|
||||
|
||||
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.blend_v
|
||||
def blend_v(self, a, b, blend_extent):
|
||||
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
|
||||
for y in range(blend_extent):
|
||||
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
|
||||
return b
|
||||
|
||||
# Copied from diffusers.models.autoencoder_kl.AutoencoderKL.blend_h
|
||||
def blend_h(self, a, b, blend_extent):
|
||||
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
|
||||
for x in range(blend_extent):
|
||||
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
|
||||
return b
|
||||
|
||||
def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> ConsistencyDecoderVAEOutput:
|
||||
r"""Encode a batch of images using a tiled encoder.
|
||||
|
||||
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
|
||||
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
|
||||
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
|
||||
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
|
||||
output, but they should be much less noticeable.
|
||||
|
||||
Args:
|
||||
x (`torch.FloatTensor`): Input batch of images.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] instead of a
|
||||
plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] is returned,
|
||||
otherwise a plain `tuple` is returned.
|
||||
"""
|
||||
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
|
||||
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
|
||||
row_limit = self.tile_latent_min_size - blend_extent
|
||||
|
||||
# Split the image into 512x512 tiles and encode them separately.
|
||||
rows = []
|
||||
for i in range(0, x.shape[2], overlap_size):
|
||||
row = []
|
||||
for j in range(0, x.shape[3], overlap_size):
|
||||
tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
|
||||
tile = self.encoder(tile)
|
||||
tile = self.quant_conv(tile)
|
||||
row.append(tile)
|
||||
rows.append(row)
|
||||
result_rows = []
|
||||
for i, row in enumerate(rows):
|
||||
result_row = []
|
||||
for j, tile in enumerate(row):
|
||||
# blend the above tile and the left tile
|
||||
# to the current tile and add the current tile to the result row
|
||||
if i > 0:
|
||||
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
||||
if j > 0:
|
||||
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
||||
result_row.append(tile[:, :, :row_limit, :row_limit])
|
||||
result_rows.append(torch.cat(result_row, dim=3))
|
||||
|
||||
moments = torch.cat(result_rows, dim=2)
|
||||
posterior = DiagonalGaussianDistribution(moments)
|
||||
|
||||
if not return_dict:
|
||||
return (posterior,)
|
||||
|
||||
return ConsistencyDecoderVAEOutput(latent_dist=posterior)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[DecoderOutput, torch.FloatTensor]:
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.FloatTensor`): Input sample.
|
||||
sample_posterior (`bool`, *optional*, defaults to `False`):
|
||||
Whether to sample from the posterior.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
|
||||
"""
|
||||
x = sample
|
||||
posterior = self.encode(x).latent_dist
|
||||
if sample_posterior:
|
||||
z = posterior.sample(generator=generator)
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z, generator=generator).sample
|
||||
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
|
||||
return DecoderOutput(sample=dec)
|
||||
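For direct use of the VAE outside a pipeline, here is a sketch following the `test_encode_decode` integration test added at the end of this diff; it assumes a CUDA device and the same provisional `openai/consistency-decoder` checkpoint id:

```py
import numpy as np
import torch
from diffusers import ConsistencyDecoderVAE
from diffusers.utils.loading_utils import load_image

vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder").to("cuda")

image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
    "/img2img/sketch-mountains-input.jpg"
).resize((256, 256))
# Scale pixels to [-1, 1] and add a batch dimension, as in the integration test.
image = torch.from_numpy(np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1)[None].cuda()

latent = vae.encode(image).latent_dist.mean  # encoding uses the regular KL encoder
# decoding runs the 2-step consistency sampler, so a generator makes it reproducible
sample = vae.decode(latent, generator=torch.Generator("cpu").manual_seed(0)).sample
```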
@@ -117,6 +117,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
add_attention: bool = True,
|
||||
class_embed_type: Optional[str] = None,
|
||||
num_class_embeds: Optional[int] = None,
|
||||
num_train_timesteps: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -144,6 +145,9 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
elif time_embedding_type == "positional":
|
||||
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
elif time_embedding_type == "learned":
|
||||
self.time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0])
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
|
||||
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
|
||||
|
||||
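These `UNet2DModel` hunks add a `"learned"` option for the timestep projection: instead of the fixed sinusoidal `Timesteps` features, every integer timestep gets its own trainable vector, which is why `num_train_timesteps` must now be passed. A small sketch of the difference (names follow the hunk above; the surrounding module wiring is simplified):

```py
import torch
from torch import nn
from diffusers.models.embeddings import Timesteps

dim = 32
t = torch.tensor([1008, 512])

# "positional": fixed sinusoidal features, defined for any timestep value.
positional = Timesteps(dim, flip_sin_to_cos=True, downscale_freq_shift=0)(t)

# "learned": one trainable embedding per discrete timestep, so the number of
# training timesteps (1024 for the consistency decoder) must be known up front.
learned = nn.Embedding(1024, dim)(t)
```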
@@ -852,7 +852,9 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
if not output_type == "latent":
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
|
||||
0
|
||||
]
|
||||
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
||||
else:
|
||||
image = latents
|
||||
|
||||
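This and the following pipeline hunks all make the same mechanical change: `generator` is now forwarded to `vae.decode`. `AutoencoderKL.decode` accepts and ignores the new argument (see its signature change earlier in this diff), but a `ConsistencyDecoderVAE` draws noise while decoding, so forwarding the generator keeps the whole pipeline reproducible. Schematically, the shared post-denoising step now reads (a sketch, not the full method):

```py
# `latents`, `generator` and `output_type` come from the pipeline's __call__.
if output_type != "latent":
    image = self.vae.decode(
        latents / self.vae.config.scaling_factor,
        return_dict=False,
        generator=generator,  # only stochastic decoders such as ConsistencyDecoderVAE use it
    )[0]
else:
    image = latents
```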
@@ -893,7 +893,9 @@ class AltDiffusionImg2ImgPipeline(
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
if not output_type == "latent":
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
|
||||
0
|
||||
]
|
||||
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
||||
else:
|
||||
image = latents
|
||||
|
||||
@@ -6,7 +6,9 @@ from ...utils import (
|
||||
)
|
||||
|
||||
|
||||
_import_structure = {"pipeline_consistency_models": ["ConsistencyModelPipeline"]}
|
||||
_import_structure = {
|
||||
"pipeline_consistency_models": ["ConsistencyModelPipeline"],
|
||||
}
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .pipeline_consistency_models import ConsistencyModelPipeline
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -1058,7 +1058,9 @@ class StableDiffusionControlNetPipeline(
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if not output_type == "latent":
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
|
||||
0
|
||||
]
|
||||
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
||||
else:
|
||||
image = latents
|
||||
|
||||
@@ -1138,7 +1138,9 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if not output_type == "latent":
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
|
||||
0
|
||||
]
|
||||
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
||||
else:
|
||||
image = latents
|
||||
|
||||
@@ -1405,7 +1405,9 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if not output_type == "latent":
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
|
||||
0
|
||||
]
|
||||
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
||||
else:
|
||||
image = latents
|
||||
|
||||
@@ -838,7 +838,9 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
if not output_type == "latent":
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
|
||||
0
|
||||
]
|
||||
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
||||
else:
|
||||
image = latents
|
||||
|
||||
@@ -885,7 +885,9 @@ class StableDiffusionImg2ImgPipeline(
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
if not output_type == "latent":
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
|
||||
0
|
||||
]
|
||||
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
||||
else:
|
||||
image = latents
|
||||
|
||||
@@ -1159,7 +1159,9 @@ class StableDiffusionInpaintPipeline(
|
||||
init_image = self._encode_vae_image(init_image, generator=generator)
|
||||
mask_condition = mask_condition.to(device=device, dtype=masked_image_latents.dtype)
|
||||
condition_kwargs = {"image": init_image_condition, "mask": mask_condition}
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, **condition_kwargs)[0]
|
||||
image = self.vae.decode(
|
||||
latents / self.vae.config.scaling_factor, return_dict=False, generator=generator, **condition_kwargs
|
||||
)[0]
|
||||
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
||||
else:
|
||||
image = latents
|
||||
|
||||
@@ -38,6 +38,7 @@ except OptionalDependencyNotAvailable:
|
||||
_dummy_modules.update(get_objects_from_module(dummy_pt_objects))
|
||||
|
||||
else:
|
||||
_import_structure["scheduling_consistency_decoder"] = ["ConsistencyDecoderScheduler"]
|
||||
_import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"]
|
||||
_import_structure["scheduling_ddim"] = ["DDIMScheduler"]
|
||||
_import_structure["scheduling_ddim_inverse"] = ["DDIMInverseScheduler"]
|
||||
@@ -128,6 +129,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_pt_objects import * # noqa F403
|
||||
else:
|
||||
from .scheduling_consistency_decoder import ConsistencyDecoderScheduler
|
||||
from .scheduling_consistency_models import CMStochasticIterativeScheduler
|
||||
from .scheduling_ddim import DDIMScheduler
|
||||
from .scheduling_ddim_inverse import DDIMInverseScheduler
|
||||
|
||||
src/diffusers/schedulers/scheduling_consistency_decoder.py (new file, 180 lines)
@@ -0,0 +1,180 @@
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput
|
||||
from ..utils.torch_utils import randn_tensor
|
||||
from .scheduling_utils import SchedulerMixin
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
|
||||
def betas_for_alpha_bar(
|
||||
num_diffusion_timesteps,
|
||||
max_beta=0.999,
|
||||
alpha_transform_type="cosine",
|
||||
):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
|
||||
(1-beta) over time from t = [0,1].
|
||||
|
||||
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
|
||||
to that part of the diffusion process.
|
||||
|
||||
|
||||
Args:
|
||||
num_diffusion_timesteps (`int`): the number of betas to produce.
|
||||
max_beta (`float`): the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
|
||||
Choose from `cosine` or `exp`
|
||||
|
||||
Returns:
|
||||
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
|
||||
"""
|
||||
if alpha_transform_type == "cosine":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
return math.exp(t * -12.0)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
|
||||
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
|
||||
return torch.tensor(betas, dtype=torch.float32)
|
||||
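In LaTeX form, the schedule built by `betas_for_alpha_bar` with the default `cosine` transform (a restatement of the code above, not new behaviour) is

$$
\bar\alpha(t)=\cos^2\!\Big(\frac{t+0.008}{1.008}\cdot\frac{\pi}{2}\Big),
\qquad
\beta_i=\min\!\Big(1-\frac{\bar\alpha\big(\tfrac{i+1}{N}\big)}{\bar\alpha\big(\tfrac{i}{N}\big)},\ \beta_\text{max}\Big),
\quad i=0,\dots,N-1
$$

with $N$ = `num_diffusion_timesteps` and $\beta_\text{max}$ = `max_beta` = 0.999.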
|
||||
|
||||
@dataclass
|
||||
class ConsistencyDecoderSchedulerOutput(BaseOutput):
|
||||
"""
|
||||
Output class for the scheduler's `step` function.
|
||||
|
||||
Args:
|
||||
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
||||
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
|
||||
denoising loop.
|
||||
"""
|
||||
|
||||
prev_sample: torch.FloatTensor
|
||||
|
||||
|
||||
class ConsistencyDecoderScheduler(SchedulerMixin, ConfigMixin):
|
||||
order = 1
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_train_timesteps: int = 1024,
|
||||
sigma_data: float = 0.5,
|
||||
):
|
||||
betas = betas_for_alpha_bar(num_train_timesteps)
|
||||
|
||||
alphas = 1.0 - betas
|
||||
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
||||
|
||||
self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
|
||||
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
|
||||
|
||||
sigmas = torch.sqrt(1.0 / alphas_cumprod - 1)
|
||||
|
||||
sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)
|
||||
|
||||
self.c_skip = sqrt_recip_alphas_cumprod * sigma_data**2 / (sigmas**2 + sigma_data**2)
|
||||
self.c_out = sigmas * sigma_data / (sigmas**2 + sigma_data**2) ** 0.5
|
||||
self.c_in = sqrt_recip_alphas_cumprod / (sigmas**2 + sigma_data**2) ** 0.5
|
||||
|
||||
def set_timesteps(
|
||||
self,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Union[str, torch.device] = None,
|
||||
):
|
||||
if num_inference_steps != 2:
|
||||
raise ValueError("Currently, only 2 inference steps are supported.")
|
||||
|
||||
self.timesteps = torch.tensor([1008, 512], dtype=torch.long, device=device)
|
||||
self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(device)
|
||||
self.sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod.to(device)
|
||||
self.c_skip = self.c_skip.to(device)
|
||||
self.c_out = self.c_out.to(device)
|
||||
self.c_in = self.c_in.to(device)
|
||||
|
||||
@property
|
||||
def init_noise_sigma(self):
|
||||
return self.sqrt_one_minus_alphas_cumprod[self.timesteps[0]]
|
||||
|
||||
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
|
||||
"""
|
||||
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
||||
current timestep.
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor`):
|
||||
The input sample.
|
||||
timestep (`int`, *optional*):
|
||||
The current timestep in the diffusion chain.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`:
|
||||
A scaled input sample.
|
||||
"""
|
||||
return sample * self.c_in[timestep]
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
timestep: Union[float, torch.FloatTensor],
|
||||
sample: torch.FloatTensor,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[ConsistencyDecoderSchedulerOutput, Tuple]:
|
||||
"""
|
||||
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
||||
process from the learned model outputs (most often the predicted noise).
|
||||
|
||||
Args:
|
||||
model_output (`torch.FloatTensor`):
|
||||
The direct output from the learned diffusion model.
|
||||
timestep (`float`):
|
||||
The current timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A random number generator.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a
|
||||
[`~schedulers.scheduling_consistency_models.ConsistencyDecoderSchedulerOutput`] or `tuple`.
|
||||
|
||||
Returns:
|
||||
[`~schedulers.scheduling_consistency_models.ConsistencyDecoderSchedulerOutput`] or `tuple`:
|
||||
If return_dict is `True`,
|
||||
[`~schedulers.scheduling_consistency_models.ConsistencyDecoderSchedulerOutput`] is returned, otherwise
|
||||
a tuple is returned where the first element is the sample tensor.
|
||||
"""
|
||||
x_0 = self.c_out[timestep] * model_output + self.c_skip[timestep] * sample
|
||||
|
||||
timestep_idx = torch.where(self.timesteps == timestep)[0]
|
||||
|
||||
if timestep_idx == len(self.timesteps) - 1:
|
||||
prev_sample = x_0
|
||||
else:
|
||||
noise = randn_tensor(x_0.shape, generator=generator, dtype=x_0.dtype, device=x_0.device)
|
||||
prev_sample = (
|
||||
self.sqrt_alphas_cumprod[self.timesteps[timestep_idx + 1]].to(x_0.dtype) * x_0
|
||||
+ self.sqrt_one_minus_alphas_cumprod[self.timesteps[timestep_idx + 1]].to(x_0.dtype) * noise
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
|
||||
return ConsistencyDecoderSchedulerOutput(prev_sample=prev_sample)
|
||||
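Summarizing the scheduler above (this restates what the code computes, it is not an addition): with $\bar\alpha_t=\prod_{s\le t}(1-\beta_s)$ from the cosine schedule and $\sigma_\text{data}=0.5$, the constructor precomputes the consistency-model scalings

$$
\sigma_t=\sqrt{\tfrac{1}{\bar\alpha_t}-1},\qquad
c_\text{skip}(t)=\frac{1}{\sqrt{\bar\alpha_t}}\cdot\frac{\sigma_\text{data}^2}{\sigma_t^2+\sigma_\text{data}^2},\qquad
c_\text{out}(t)=\frac{\sigma_t\,\sigma_\text{data}}{\sqrt{\sigma_t^2+\sigma_\text{data}^2}},\qquad
c_\text{in}(t)=\frac{1}{\sqrt{\bar\alpha_t}\,\sqrt{\sigma_t^2+\sigma_\text{data}^2}}
$$

and `step` applies them as

$$
\hat{x}_0=c_\text{out}(t)\,F_\theta\!\big(c_\text{in}(t)\,x_t,\ t\big)+c_\text{skip}(t)\,x_t,
\qquad
x_{t'}=\sqrt{\bar\alpha_{t'}}\,\hat{x}_0+\sqrt{1-\bar\alpha_{t'}}\,\varepsilon,\quad \varepsilon\sim\mathcal{N}(0,I),
$$

where $F_\theta$ is the decoder UNet (applied to the scaled sample concatenated with the latents in the VAE), $t$ runs over the two timesteps $\{1008, 512\}$, sampling starts from $x_{1008}=\sqrt{1-\bar\alpha_{1008}}\,\varepsilon$ (the `init_noise_sigma`), and the final step returns $\hat{x}_0$ directly.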
@@ -47,6 +47,21 @@ class AutoencoderTiny(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class ConsistencyDecoderVAE(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class ControlNetModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -196,11 +196,15 @@ class UNetTesterMixin:
|
||||
class ModelTesterMixin:
|
||||
main_input_name = None # overwrite in model specific tester class
|
||||
base_precision = 1e-3
|
||||
forward_requires_fresh_args = False
|
||||
|
||||
def test_from_save_pretrained(self, expected_max_diff=5e-5):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
if self.forward_requires_fresh_args:
|
||||
model = self.model_class(**self.init_dict)
|
||||
else:
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
if hasattr(model, "set_default_attn_processor"):
|
||||
model.set_default_attn_processor()
|
||||
model.to(torch_device)
|
||||
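The repeated `forward_requires_fresh_args` branches added throughout this file exist because `ConsistencyDecoderVAE`'s test inputs include a `torch.Generator` (see `ConsistencyDecoderVAETests.inputs_dict` later in this diff), and a generator's state is consumed by each forward pass, so re-using one `inputs_dict` would make supposedly identical calls diverge. That reading is an inference from the test changes rather than something stated in the PR. A minimal sketch of the dispatch pattern the mixin methods now use (names as in this file):

```py
if self.forward_requires_fresh_args:
    # e.g. ConsistencyDecoderVAE: rebuild the inputs (and a freshly seeded generator) per call
    model = self.model_class(**self.init_dict)
    output = model(**self.inputs_dict(0))
else:
    init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
    model = self.model_class(**init_dict)
    output = model(**inputs_dict)
```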
@@ -214,11 +218,18 @@ class ModelTesterMixin:
|
||||
new_model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
image = model(**inputs_dict)
|
||||
if self.forward_requires_fresh_args:
|
||||
image = model(**self.inputs_dict(0))
|
||||
else:
|
||||
image = model(**inputs_dict)
|
||||
|
||||
if isinstance(image, dict):
|
||||
image = image.to_tuple()[0]
|
||||
|
||||
new_image = new_model(**inputs_dict)
|
||||
if self.forward_requires_fresh_args:
|
||||
new_image = new_model(**self.inputs_dict(0))
|
||||
else:
|
||||
new_image = new_model(**inputs_dict)
|
||||
|
||||
if isinstance(new_image, dict):
|
||||
new_image = new_image.to_tuple()[0]
|
||||
@@ -275,8 +286,11 @@ class ModelTesterMixin:
|
||||
)
|
||||
def test_set_xformers_attn_processor_for_determinism(self):
|
||||
torch.use_deterministic_algorithms(False)
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
if self.forward_requires_fresh_args:
|
||||
model = self.model_class(**self.init_dict)
|
||||
else:
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
if not hasattr(model, "set_attn_processor"):
|
||||
@@ -286,17 +300,26 @@ class ModelTesterMixin:
|
||||
model.set_default_attn_processor()
|
||||
assert all(type(proc) == AttnProcessor for proc in model.attn_processors.values())
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)[0]
|
||||
if self.forward_requires_fresh_args:
|
||||
output = model(**self.inputs_dict(0))[0]
|
||||
else:
|
||||
output = model(**inputs_dict)[0]
|
||||
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values())
|
||||
with torch.no_grad():
|
||||
output_2 = model(**inputs_dict)[0]
|
||||
if self.forward_requires_fresh_args:
|
||||
output_2 = model(**self.inputs_dict(0))[0]
|
||||
else:
|
||||
output_2 = model(**inputs_dict)[0]
|
||||
|
||||
model.set_attn_processor(XFormersAttnProcessor())
|
||||
assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values())
|
||||
with torch.no_grad():
|
||||
output_3 = model(**inputs_dict)[0]
|
||||
if self.forward_requires_fresh_args:
|
||||
output_3 = model(**self.inputs_dict(0))[0]
|
||||
else:
|
||||
output_3 = model(**inputs_dict)[0]
|
||||
|
||||
torch.use_deterministic_algorithms(True)
|
||||
|
||||
@@ -307,8 +330,12 @@ class ModelTesterMixin:
|
||||
@require_torch_gpu
|
||||
def test_set_attn_processor_for_determinism(self):
|
||||
torch.use_deterministic_algorithms(False)
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
if self.forward_requires_fresh_args:
|
||||
model = self.model_class(**self.init_dict)
|
||||
else:
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
model.to(torch_device)
|
||||
|
||||
if not hasattr(model, "set_attn_processor"):
|
||||
@@ -317,22 +344,34 @@ class ModelTesterMixin:
|
||||
|
||||
assert all(type(proc) == AttnProcessor2_0 for proc in model.attn_processors.values())
|
||||
with torch.no_grad():
|
||||
output_1 = model(**inputs_dict)[0]
|
||||
if self.forward_requires_fresh_args:
|
||||
output_1 = model(**self.inputs_dict(0))[0]
|
||||
else:
|
||||
output_1 = model(**inputs_dict)[0]
|
||||
|
||||
model.set_default_attn_processor()
|
||||
assert all(type(proc) == AttnProcessor for proc in model.attn_processors.values())
|
||||
with torch.no_grad():
|
||||
output_2 = model(**inputs_dict)[0]
|
||||
if self.forward_requires_fresh_args:
|
||||
output_2 = model(**self.inputs_dict(0))[0]
|
||||
else:
|
||||
output_2 = model(**inputs_dict)[0]
|
||||
|
||||
model.set_attn_processor(AttnProcessor2_0())
|
||||
assert all(type(proc) == AttnProcessor2_0 for proc in model.attn_processors.values())
|
||||
with torch.no_grad():
|
||||
output_4 = model(**inputs_dict)[0]
|
||||
if self.forward_requires_fresh_args:
|
||||
output_4 = model(**self.inputs_dict(0))[0]
|
||||
else:
|
||||
output_4 = model(**inputs_dict)[0]
|
||||
|
||||
model.set_attn_processor(AttnProcessor())
|
||||
assert all(type(proc) == AttnProcessor for proc in model.attn_processors.values())
|
||||
with torch.no_grad():
|
||||
output_5 = model(**inputs_dict)[0]
|
||||
if self.forward_requires_fresh_args:
|
||||
output_5 = model(**self.inputs_dict(0))[0]
|
||||
else:
|
||||
output_5 = model(**inputs_dict)[0]
|
||||
|
||||
torch.use_deterministic_algorithms(True)
|
||||
|
||||
@@ -342,9 +381,12 @@ class ModelTesterMixin:
|
||||
assert torch.allclose(output_2, output_5, atol=self.base_precision)
|
||||
|
||||
def test_from_save_pretrained_variant(self, expected_max_diff=5e-5):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
if self.forward_requires_fresh_args:
|
||||
model = self.model_class(**self.init_dict)
|
||||
else:
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
if hasattr(model, "set_default_attn_processor"):
|
||||
model.set_default_attn_processor()
|
||||
|
||||
@@ -367,11 +409,17 @@ class ModelTesterMixin:
|
||||
new_model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
image = model(**inputs_dict)
|
||||
if self.forward_requires_fresh_args:
|
||||
image = model(**self.inputs_dict(0))
|
||||
else:
|
||||
image = model(**inputs_dict)
|
||||
if isinstance(image, dict):
|
||||
image = image.to_tuple()[0]
|
||||
|
||||
new_image = new_model(**inputs_dict)
|
||||
if self.forward_requires_fresh_args:
|
||||
new_image = new_model(**self.inputs_dict(0))
|
||||
else:
|
||||
new_image = new_model(**inputs_dict)
|
||||
|
||||
if isinstance(new_image, dict):
|
||||
new_image = new_image.to_tuple()[0]
|
||||
@@ -405,17 +453,26 @@ class ModelTesterMixin:
|
||||
assert new_model.dtype == dtype
|
||||
|
||||
def test_determinism(self, expected_max_diff=1e-5):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
if self.forward_requires_fresh_args:
|
||||
model = self.model_class(**self.init_dict)
|
||||
else:
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
first = model(**inputs_dict)
|
||||
if self.forward_requires_fresh_args:
|
||||
first = model(**self.inputs_dict(0))
|
||||
else:
|
||||
first = model(**inputs_dict)
|
||||
if isinstance(first, dict):
|
||||
first = first.to_tuple()[0]
|
||||
|
||||
second = model(**inputs_dict)
|
||||
if self.forward_requires_fresh_args:
|
||||
second = model(**self.inputs_dict(0))
|
||||
else:
|
||||
second = model(**inputs_dict)
|
||||
if isinstance(second, dict):
|
||||
second = second.to_tuple()[0]
|
||||
|
||||
@@ -548,15 +605,22 @@ class ModelTesterMixin:
|
||||
),
|
||||
)
|
||||
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
if self.forward_requires_fresh_args:
|
||||
model = self.model_class(**self.init_dict)
|
||||
else:
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
outputs_dict = model(**inputs_dict)
|
||||
outputs_tuple = model(**inputs_dict, return_dict=False)
|
||||
if self.forward_requires_fresh_args:
|
||||
outputs_dict = model(**self.inputs_dict(0))
|
||||
outputs_tuple = model(**self.inputs_dict(0), return_dict=False)
|
||||
else:
|
||||
outputs_dict = model(**inputs_dict)
|
||||
outputs_tuple = model(**inputs_dict, return_dict=False)
|
||||
|
||||
recursive_check(outputs_tuple, outputs_dict)
|
||||
|
||||
|
||||
@@ -16,11 +16,19 @@
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from parameterized import parameterized
|
||||
|
||||
from diffusers import AsymmetricAutoencoderKL, AutoencoderKL, AutoencoderTiny
|
||||
from diffusers import (
|
||||
AsymmetricAutoencoderKL,
|
||||
AutoencoderKL,
|
||||
AutoencoderTiny,
|
||||
ConsistencyDecoderVAE,
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.loading_utils import load_image
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
@@ -30,6 +38,7 @@ from diffusers.utils.testing_utils import (
|
||||
torch_all_close,
|
||||
torch_device,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
@@ -269,6 +278,79 @@ class AutoencoderTinyTests(ModelTesterMixin, unittest.TestCase):
|
||||
pass
|
||||
|
||||
|
||||
class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = ConsistencyDecoderVAE
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
forward_requires_fresh_args = True
|
||||
|
||||
def inputs_dict(self, seed=None):
|
||||
generator = torch.Generator("cpu")
|
||||
if seed is not None:
|
||||
generator.manual_seed(seed)
|
||||
image = randn_tensor((4, 3, 32, 32), generator=generator, device=torch.device(torch_device))
|
||||
|
||||
return {"sample": image, "generator": generator}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
@property
|
||||
def init_dict(self):
|
||||
return {
|
||||
"encoder_args": {
|
||||
"block_out_channels": [32, 64],
|
||||
"in_channels": 3,
|
||||
"out_channels": 4,
|
||||
"down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
},
|
||||
"decoder_args": {
|
||||
"act_fn": "silu",
|
||||
"add_attention": False,
|
||||
"block_out_channels": [32, 64],
|
||||
"down_block_types": [
|
||||
"ResnetDownsampleBlock2D",
|
||||
"ResnetDownsampleBlock2D",
|
||||
],
|
||||
"downsample_padding": 1,
|
||||
"downsample_type": "conv",
|
||||
"dropout": 0.0,
|
||||
"in_channels": 7,
|
||||
"layers_per_block": 1,
|
||||
"norm_eps": 1e-05,
|
||||
"norm_num_groups": 32,
|
||||
"num_train_timesteps": 1024,
|
||||
"out_channels": 6,
|
||||
"resnet_time_scale_shift": "scale_shift",
|
||||
"time_embedding_type": "learned",
|
||||
"up_block_types": [
|
||||
"ResnetUpsampleBlock2D",
|
||||
"ResnetUpsampleBlock2D",
|
||||
],
|
||||
"upsample_type": "conv",
|
||||
},
|
||||
"scaling_factor": 1,
|
||||
"block_out_channels": [32, 64],
|
||||
"latent_channels": 4,
|
||||
}
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return self.init_dict, self.inputs_dict()
|
||||
|
||||
@unittest.skip
|
||||
def test_training(self):
|
||||
...
|
||||
|
||||
@unittest.skip
|
||||
def test_ema_training(self):
|
||||
...
|
||||
|
||||
|
||||
@slow
|
||||
class AutoencoderTinyIntegrationTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
@@ -721,3 +803,94 @@ class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
|
||||
|
||||
tolerance = 3e-3 if torch_device != "mps" else 1e-2
|
||||
assert torch_all_close(output_slice, expected_output_slice, atol=tolerance)
|
||||
|
||||
|
||||
@slow
|
||||
class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_encode_decode(self):
|
||||
vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder") # TODO - update
|
||||
vae.to(torch_device)
|
||||
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
"/img2img/sketch-mountains-input.jpg"
|
||||
).resize((256, 256))
|
||||
image = torch.from_numpy(np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1)[
|
||||
None, :, :, :
|
||||
].cuda()
|
||||
|
||||
latent = vae.encode(image).latent_dist.mean
|
||||
|
||||
sample = vae.decode(latent, generator=torch.Generator("cpu").manual_seed(0)).sample
|
||||
|
||||
actual_output = sample[0, :2, :2, :2].flatten().cpu()
|
||||
expected_output = torch.tensor([-0.0141, -0.0014, 0.0115, 0.0086, 0.1051, 0.1053, 0.1031, 0.1024])
|
||||
|
||||
assert torch_all_close(actual_output, expected_output, atol=5e-3)
|
||||
|
||||
def test_sd(self):
|
||||
vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder") # TODO - update
|
||||
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", vae=vae, safety_checker=None)
|
||||
pipe.to(torch_device)
|
||||
|
||||
out = pipe(
|
||||
"horse", num_inference_steps=2, output_type="pt", generator=torch.Generator("cpu").manual_seed(0)
|
||||
).images[0]
|
||||
|
||||
actual_output = out[:2, :2, :2].flatten().cpu()
|
||||
expected_output = torch.tensor([0.7686, 0.8228, 0.6489, 0.7455, 0.8661, 0.8797, 0.8241, 0.8759])
|
||||
|
||||
assert torch_all_close(actual_output, expected_output, atol=5e-3)
|
||||
|
||||
def test_encode_decode_f16(self):
|
||||
vae = ConsistencyDecoderVAE.from_pretrained(
|
||||
"openai/consistency-decoder", torch_dtype=torch.float16
|
||||
) # TODO - update
|
||||
vae.to(torch_device)
|
||||
|
||||
image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
"/img2img/sketch-mountains-input.jpg"
|
||||
).resize((256, 256))
|
||||
image = (
|
||||
torch.from_numpy(np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1)[None, :, :, :]
|
||||
.half()
|
||||
.cuda()
|
||||
)
|
||||
|
||||
latent = vae.encode(image).latent_dist.mean
|
||||
|
||||
sample = vae.decode(latent, generator=torch.Generator("cpu").manual_seed(0)).sample
|
||||
|
||||
actual_output = sample[0, :2, :2, :2].flatten().cpu()
|
||||
expected_output = torch.tensor(
|
||||
[-0.0111, -0.0125, -0.0017, -0.0007, 0.1257, 0.1465, 0.1450, 0.1471], dtype=torch.float16
|
||||
)
|
||||
|
||||
assert torch_all_close(actual_output, expected_output, atol=5e-3)
|
||||
|
||||
def test_sd_f16(self):
|
||||
vae = ConsistencyDecoderVAE.from_pretrained(
|
||||
"openai/consistency-decoder", torch_dtype=torch.float16
|
||||
) # TODO - update
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, vae=vae, safety_checker=None
|
||||
)
|
||||
pipe.to(torch_device)
|
||||
|
||||
out = pipe(
|
||||
"horse", num_inference_steps=2, output_type="pt", generator=torch.Generator("cpu").manual_seed(0)
|
||||
).images[0]
|
||||
|
||||
actual_output = out[:2, :2, :2].flatten().cpu()
|
||||
expected_output = torch.tensor(
|
||||
[0.0000, 0.0249, 0.0000, 0.0000, 0.1709, 0.2773, 0.0471, 0.1035], dtype=torch.float16
|
||||
)
|
||||
|
||||
assert torch_all_close(actual_output, expected_output, atol=5e-3)