
Add Consistency Models Pipeline (#3492)

* initial commit

* Improve consistency models sampling implementation.

* Add CMStochasticIterativeScheduler, which implements the multi-step sampler (stochastic_iterative_sampler) in the original code, and make further improvements to sampling.

* Add Unet blocks for consistency models

* Add conversion script for Unet

* Fix bug in new unet blocks

* Fix attention weight loading

* Make design improvements to ConsistencyModelPipeline and CMStochasticIterativeScheduler and add initial version of tests.

* make style

* Make small random test UNet class conditional and set resnet_time_scale_shift to 'scale_shift' to better match consistency model checkpoints.

* Add support for converting a test UNet and non-class-conditional UNets to the consistency models conversion script.

* make style

* Change num_class_embeds to 1000 to better match the original consistency models implementation.

* Add support for distillation in pipeline_consistency_models.py.

* Improve consistency model tests:
	- Get small testing checkpoints from hub
	- Modify tests to take into account "distillation" parameter of ConsistencyModelPipeline
	- Add onestep, multistep tests for distillation and distillation + class conditional
	- Add expected image slices for onestep tests

* make style

* Improve ConsistencyModelPipeline:
	- Add initial support for class-conditional generation
	- Fix initial sigma for onestep generation
	- Fix some sigma shape issues

* make style

* Improve ConsistencyModelPipeline:
	- add latents __call__ argument and prepare_latents method
	- add check_inputs method
	- add initial docstrings for ConsistencyModelPipeline.__call__

* make style

* Fix bug when randomly generating class labels for class-conditional generation.

* Switch CMStochasticIterativeScheduler to configuring a sigma schedule and make related changes to the pipeline and tests.

* Remove some unused code and make style.

* Fix small bug in CMStochasticIterativeScheduler.

* Add expected slices for multistep sampling tests and make them pass.

* Work on consistency model fast tests:
	- in pipeline, call self.scheduler.scale_model_input before denoising
	- get expected slices for Euler and Heun scheduler tests
	- make Euler test pass
	- mark Heun test as expected fail because it doesn't support prediction_type "sample" yet
	- remove DPM and Euler Ancestral tests because they don't support use_karras_sigmas

* make style

* Refactor conversion script to make it easier to add more model architectures to convert in the future.

* Work on ConsistencyModelPipeline tests:
	- Fix device bug when handling class labels in ConsistencyModelPipeline.__call__
	- Add slow tests for onestep and multistep sampling and make them pass
	- Refactor fast tests
	- Refactor ConsistencyModelPipeline.__init__

* make style

* Remove the add_noise and add_noise_to_input methods from CMStochasticIterativeScheduler for now.

* Run python utils/check_copies.py --fix_and_overwrite
python utils/check_dummies.py --fix_and_overwrite to make dummy objects for new pipeline and scheduler.

* Make fast tests from PipelineTesterMixin pass.

* make style

* Refactor consistency models pipeline and scheduler:
	- Remove support for Karras schedulers (only support CMStochasticIterativeScheduler)
	- Move sigma manipulation, input scaling, denoising from pipeline to scheduler
	- Make corresponding changes to tests and ensure they pass

* make style

* Add docstrings and further refactor pipeline and scheduler.

* make style

* Add initial version of the consistency models documentation.

* Refactor custom timesteps logic following DDPMScheduler/IFPipeline and temporarily add torch 2.0 SDPA kernel selection logic for debugging.

* make style

* Convert current slow tests to use fp16 and flash attention.

* make style

* Add slow tests for normal attention on cuda device.

* make style

* Fix attention weights loading

* Update consistency model fast tests for new test checkpoints with attention fix.

* make style

* apply suggestions

* Add add_noise method to CMStochasticIterativeScheduler (copied from EulerDiscreteScheduler).

* Conversion script now outputs pipeline instead of UNet and add support for LSUN-256 models and different schedulers.

* When both timesteps and num_inference_steps are supplied, raise warning instead of error (timesteps take precedence).

* make style

* Add remaining diffusers model checkpoints for models in the original consistency model release and update usage example.

* apply suggestions from review

* make style

* fix attention naming

* Add tests for CMStochasticIterativeScheduler.

* make style

* Make CMStochasticIterativeScheduler tests pass.

* make style

* Override test_step_shape in CMStochasticIterativeSchedulerTest instead of modifying it in SchedulerCommonTest.

* make style

* rename some models

* Improve API

* rename some models

* Remove duplicated block

* Add docstring and make torch compile work

* More fixes

* Fixes

* Apply suggestions from code review

* Apply suggestions from code review

* add more docstring

* update consistency conversion script

---------

Co-authored-by: ayushmangal <ayushmangal@microsoft.com>
Co-authored-by: Ayush Mangal <43698245+ayushtues@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Author: dg845
Date: 2023-07-05 10:33:58 -07:00
Committed by: GitHub
Commit: aed7499a8d (parent: 07c9a08e67)
17 changed files with 1710 additions and 13 deletions


@@ -184,6 +184,8 @@
title: Audio Diffusion
- local: api/pipelines/audioldm
title: AudioLDM
- local: api/pipelines/consistency_models
title: Consistency Models
- local: api/pipelines/controlnet
title: ControlNet
- local: api/pipelines/cycle_diffusion
@@ -274,6 +276,8 @@
- sections:
- local: api/schedulers/overview
title: Overview
- local: api/schedulers/cm_stochastic_iterative
title: Consistency Model Multistep Scheduler
- local: api/schedulers/ddim
title: DDIM
- local: api/schedulers/ddim_inverse


@@ -0,0 +1,87 @@
# Consistency Models
Consistency Models were proposed in [Consistency Models](https://arxiv.org/abs/2303.01469) by Yang Song, Prafulla Dhariwal, Mark Chen, and Ilya Sutskever.
The abstract of the [paper](https://arxiv.org/pdf/2303.01469.pdf) is as follows:
*Diffusion models have significantly advanced the fields of image, audio, and video generation, but they depend on an iterative sampling process that causes slow generation. To overcome this limitation, we propose consistency models, a new family of models that generate high quality samples by directly mapping noise to data. They support fast one-step generation by design, while still allowing multistep sampling to trade compute for sample quality. They also support zero-shot data editing, such as image inpainting, colorization, and super-resolution, without requiring explicit training on these tasks. Consistency models can be trained either by distilling pre-trained diffusion models, or as standalone generative models altogether. Through extensive experiments, we demonstrate that they outperform existing distillation techniques for diffusion models in one- and few-step sampling, achieving the new state-of-the-art FID of 3.55 on CIFAR-10 and 6.20 on ImageNet 64x64 for one-step generation. When trained in isolation, consistency models become a new family of generative models that can outperform existing one-step, non-adversarial generative models on standard benchmarks such as CIFAR-10, ImageNet 64x64 and LSUN 256x256.*
Resources:
* [Paper](https://arxiv.org/abs/2303.01469)
* [Original Code](https://github.com/openai/consistency_models)
Available Checkpoints are:
- *cd_imagenet64_l2 (64x64 resolution)* [openai/consistency-model-pipelines](https://huggingface.co/openai/consistency-model-pipelines)
- *cd_imagenet64_lpips (64x64 resolution)* [openai/diffusers-cd_imagenet64_lpips](https://huggingface.co/openai/diffusers-cd_imagenet64_lpips)
- *ct_imagenet64 (64x64 resolution)* [openai/diffusers-ct_imagenet64](https://huggingface.co/openai/diffusers-ct_imagenet64)
- *cd_bedroom256_l2 (256x256 resolution)* [openai/diffusers-cd_bedroom256_l2](https://huggingface.co/openai/diffusers-cd_bedroom256_l2)
- *cd_bedroom256_lpips (256x256 resolution)* [openai/diffusers-cd_bedroom256_lpips](https://huggingface.co/openai/diffusers-cd_bedroom256_lpips)
- *ct_bedroom256 (256x256 resolution)* [openai/diffusers-ct_bedroom256](https://huggingface.co/openai/diffusers-ct_bedroom256)
- *cd_cat256_l2 (256x256 resolution)* [openai/diffusers-cd_cat256_l2](https://huggingface.co/openai/diffusers-cd_cat256_l2)
- *cd_cat256_lpips (256x256 resolution)* [openai/diffusers-cd_cat256_lpips](https://huggingface.co/openai/diffusers-cd_cat256_lpips)
- *ct_cat256 (256x256 resolution)* [openai/diffusers-ct_cat256](https://huggingface.co/openai/diffusers-ct_cat256)
## Available Pipelines
| Pipeline | Tasks | Demo | Colab |
|:---:|:---:|:---:|:---:|
| [ConsistencyModelPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py) | *Unconditional Image Generation* | | |
This pipeline was contributed by our community members [dg845](https://github.com/dg845) and [ayushtues](https://huggingface.co/ayushtues) :heart:
## Usage Example
```python
import torch
from diffusers import ConsistencyModelPipeline
device = "cuda"
# Load the cd_imagenet64_l2 checkpoint.
model_id_or_path = "openai/diffusers-cd_imagenet64_l2"
pipe = ConsistencyModelPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
pipe.to(device)
# Onestep Sampling
image = pipe(num_inference_steps=1).images[0]
image.save("consistency_model_onestep_sample.png")
# Onestep sampling, class-conditional image generation
# ImageNet-64 class label 145 corresponds to king penguins
image = pipe(num_inference_steps=1, class_labels=145).images[0]
image.save("consistency_model_onestep_sample_penguin.png")
# Multistep sampling, class-conditional image generation
# Timesteps can be explicitly specified; the particular timesteps below are from the original GitHub repo:
# https://github.com/openai/consistency_models/blob/main/scripts/launch.sh#L77
image = pipe(timesteps=[22, 0], class_labels=145).images[0]
image.save("consistency_model_multistep_sample_penguin.png")
```
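For reproducible outputs, `ConsistencyModelPipeline.__call__` also accepts a `generator` (and, optionally, pre-generated `latents`). A minimal sketch, reusing the `pipe` and `device` from the example above:
```python
# Seeded generation (a sketch): the generator seeds both the initial noise and
# the per-step noise injected by the scheduler during multistep sampling.
generator = torch.Generator(device=device).manual_seed(0)
image = pipe(num_inference_steps=1, class_labels=145, generator=generator).images[0]
```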
For an additional speed-up, one can also make use of `torch.compile`. Multiple images can be generated in <1 second as follows:
```py
import torch
from diffusers import ConsistencyModelPipeline
device = "cuda"
# Load the cd_bedroom256_lpips checkpoint.
model_id_or_path = "openai/diffusers-cd_bedroom256_lpips"
pipe = ConsistencyModelPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
pipe.to(device)
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
# Multistep sampling
# Timesteps can be explicitly specified; the particular timesteps below are from the original GitHub repo:
# https://github.com/openai/consistency_models/blob/main/scripts/launch.sh#L83
for _ in range(10):
    image = pipe(timesteps=[17, 0]).images[0]
    image.show()
```
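Here `mode="reduce-overhead"` makes `torch.compile` use CUDA graphs, trading extra compilation time and memory for lower per-call launch overhead; this pays off because the compiled UNet is invoked repeatedly across the loop.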
## ConsistencyModelPipeline
[[autodoc]] ConsistencyModelPipeline
- all
- __call__


@@ -0,0 +1,11 @@
# Consistency Model Multistep Scheduler
## Overview
Multistep and onestep scheduler (Algorithm 1) introduced alongside consistency models in the paper [Consistency Models](https://arxiv.org/abs/2303.01469) by Yang Song, Prafulla Dhariwal, Mark Chen, and Ilya Sutskever.
Based on the [original consistency models implementation](https://github.com/openai/consistency_models).
Should generate good samples from [`ConsistencyModelPipeline`] in one or a small number of steps.
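As a minimal sketch (the config values below are the consistency-distillation defaults used elsewhere in this commit), the scheduler can also be constructed explicitly and handed to the pipeline:
```python
import torch

from diffusers import CMStochasticIterativeScheduler, ConsistencyModelPipeline

# Construct the scheduler with the CD checkpoint defaults (num_train_timesteps=40,
# sigma_min=0.002, sigma_max=80.0) and pass it to the pipeline explicitly.
scheduler = CMStochasticIterativeScheduler(num_train_timesteps=40, sigma_min=0.002, sigma_max=80.0)
pipe = ConsistencyModelPipeline.from_pretrained(
    "openai/diffusers-cd_imagenet64_l2", scheduler=scheduler, torch_dtype=torch.float16
).to("cuda")
image = pipe(num_inference_steps=2).images[0]
```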
## CMStochasticIterativeScheduler
[[autodoc]] CMStochasticIterativeScheduler


@@ -0,0 +1,313 @@
import argparse
import os
import torch
from diffusers import (
CMStochasticIterativeScheduler,
ConsistencyModelPipeline,
UNet2DModel,
)
TEST_UNET_CONFIG = {
"sample_size": 32,
"in_channels": 3,
"out_channels": 3,
"layers_per_block": 2,
"num_class_embeds": 1000,
"block_out_channels": [32, 64],
"attention_head_dim": 8,
"down_block_types": [
"ResnetDownsampleBlock2D",
"AttnDownBlock2D",
],
"up_block_types": [
"AttnUpBlock2D",
"ResnetUpsampleBlock2D",
],
"resnet_time_scale_shift": "scale_shift",
"upsample_type": "resnet",
"downsample_type": "resnet",
}
IMAGENET_64_UNET_CONFIG = {
"sample_size": 64,
"in_channels": 3,
"out_channels": 3,
"layers_per_block": 3,
"num_class_embeds": 1000,
"block_out_channels": [192, 192 * 2, 192 * 3, 192 * 4],
"attention_head_dim": 64,
"down_block_types": [
"ResnetDownsampleBlock2D",
"AttnDownBlock2D",
"AttnDownBlock2D",
"AttnDownBlock2D",
],
"up_block_types": [
"AttnUpBlock2D",
"AttnUpBlock2D",
"AttnUpBlock2D",
"ResnetUpsampleBlock2D",
],
"resnet_time_scale_shift": "scale_shift",
"upsample_type": "resnet",
"downsample_type": "resnet",
}
LSUN_256_UNET_CONFIG = {
"sample_size": 256,
"in_channels": 3,
"out_channels": 3,
"layers_per_block": 2,
"num_class_embeds": None,
"block_out_channels": [256, 256, 256 * 2, 256 * 2, 256 * 4, 256 * 4],
"attention_head_dim": 64,
"down_block_types": [
"ResnetDownsampleBlock2D",
"ResnetDownsampleBlock2D",
"ResnetDownsampleBlock2D",
"AttnDownBlock2D",
"AttnDownBlock2D",
"AttnDownBlock2D",
],
"up_block_types": [
"AttnUpBlock2D",
"AttnUpBlock2D",
"AttnUpBlock2D",
"ResnetUpsampleBlock2D",
"ResnetUpsampleBlock2D",
"ResnetUpsampleBlock2D",
],
"resnet_time_scale_shift": "default",
"upsample_type": "resnet",
"downsample_type": "resnet",
}
CD_SCHEDULER_CONFIG = {
"num_train_timesteps": 40,
"sigma_min": 0.002,
"sigma_max": 80.0,
}
CT_IMAGENET_64_SCHEDULER_CONFIG = {
"num_train_timesteps": 201,
"sigma_min": 0.002,
"sigma_max": 80.0,
}
CT_LSUN_256_SCHEDULER_CONFIG = {
"num_train_timesteps": 151,
"sigma_min": 0.002,
"sigma_max": 80.0,
}
def str2bool(v):
    """
    https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("boolean value expected")
def convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=False):
    new_checkpoint[f"{new_prefix}.norm1.weight"] = checkpoint[f"{old_prefix}.in_layers.0.weight"]
    new_checkpoint[f"{new_prefix}.norm1.bias"] = checkpoint[f"{old_prefix}.in_layers.0.bias"]
    new_checkpoint[f"{new_prefix}.conv1.weight"] = checkpoint[f"{old_prefix}.in_layers.2.weight"]
    new_checkpoint[f"{new_prefix}.conv1.bias"] = checkpoint[f"{old_prefix}.in_layers.2.bias"]
    new_checkpoint[f"{new_prefix}.time_emb_proj.weight"] = checkpoint[f"{old_prefix}.emb_layers.1.weight"]
    new_checkpoint[f"{new_prefix}.time_emb_proj.bias"] = checkpoint[f"{old_prefix}.emb_layers.1.bias"]
    new_checkpoint[f"{new_prefix}.norm2.weight"] = checkpoint[f"{old_prefix}.out_layers.0.weight"]
    new_checkpoint[f"{new_prefix}.norm2.bias"] = checkpoint[f"{old_prefix}.out_layers.0.bias"]
    new_checkpoint[f"{new_prefix}.conv2.weight"] = checkpoint[f"{old_prefix}.out_layers.3.weight"]
    new_checkpoint[f"{new_prefix}.conv2.bias"] = checkpoint[f"{old_prefix}.out_layers.3.bias"]

    if has_skip:
        new_checkpoint[f"{new_prefix}.conv_shortcut.weight"] = checkpoint[f"{old_prefix}.skip_connection.weight"]
        new_checkpoint[f"{new_prefix}.conv_shortcut.bias"] = checkpoint[f"{old_prefix}.skip_connection.bias"]

    return new_checkpoint
def convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_dim=None):
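    # The original checkpoints store attention as a fused 1x1 qkv convolution; split it into
    # separate to_q/to_k/to_v projections and drop the trailing spatial dims of each kernel.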
    weight_q, weight_k, weight_v = checkpoint[f"{old_prefix}.qkv.weight"].chunk(3, dim=0)
    bias_q, bias_k, bias_v = checkpoint[f"{old_prefix}.qkv.bias"].chunk(3, dim=0)

    new_checkpoint[f"{new_prefix}.group_norm.weight"] = checkpoint[f"{old_prefix}.norm.weight"]
    new_checkpoint[f"{new_prefix}.group_norm.bias"] = checkpoint[f"{old_prefix}.norm.bias"]

    new_checkpoint[f"{new_prefix}.to_q.weight"] = weight_q.squeeze(-1).squeeze(-1)
    new_checkpoint[f"{new_prefix}.to_q.bias"] = bias_q.squeeze(-1).squeeze(-1)
    new_checkpoint[f"{new_prefix}.to_k.weight"] = weight_k.squeeze(-1).squeeze(-1)
    new_checkpoint[f"{new_prefix}.to_k.bias"] = bias_k.squeeze(-1).squeeze(-1)
    new_checkpoint[f"{new_prefix}.to_v.weight"] = weight_v.squeeze(-1).squeeze(-1)
    new_checkpoint[f"{new_prefix}.to_v.bias"] = bias_v.squeeze(-1).squeeze(-1)

    new_checkpoint[f"{new_prefix}.to_out.0.weight"] = (
        checkpoint[f"{old_prefix}.proj_out.weight"].squeeze(-1).squeeze(-1)
    )
    new_checkpoint[f"{new_prefix}.to_out.0.bias"] = checkpoint[f"{old_prefix}.proj_out.bias"].squeeze(-1).squeeze(-1)

    return new_checkpoint
def con_pt_to_diffuser(checkpoint_path: str, unet_config):
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    new_checkpoint = {}

    new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["time_embed.2.bias"]

    if unet_config["num_class_embeds"] is not None:
        new_checkpoint["class_embedding.weight"] = checkpoint["label_emb.weight"]

    new_checkpoint["conv_in.weight"] = checkpoint["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = checkpoint["input_blocks.0.0.bias"]

    down_block_types = unet_config["down_block_types"]
    layers_per_block = unet_config["layers_per_block"]
    attention_head_dim = unet_config["attention_head_dim"]
    channels_list = unet_config["block_out_channels"]
    current_layer = 1
    prev_channels = channels_list[0]

    for i, layer_type in enumerate(down_block_types):
        current_channels = channels_list[i]
        downsample_block_has_skip = current_channels != prev_channels
        if layer_type == "ResnetDownsampleBlock2D":
            for j in range(layers_per_block):
                new_prefix = f"down_blocks.{i}.resnets.{j}"
                old_prefix = f"input_blocks.{current_layer}.0"
                has_skip = True if j == 0 and downsample_block_has_skip else False
                new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=has_skip)
                current_layer += 1
        elif layer_type == "AttnDownBlock2D":
            for j in range(layers_per_block):
                new_prefix = f"down_blocks.{i}.resnets.{j}"
                old_prefix = f"input_blocks.{current_layer}.0"
                has_skip = True if j == 0 and downsample_block_has_skip else False
                new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=has_skip)
                new_prefix = f"down_blocks.{i}.attentions.{j}"
                old_prefix = f"input_blocks.{current_layer}.1"
                new_checkpoint = convert_attention(
                    checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim
                )
                current_layer += 1

        if i != len(down_block_types) - 1:
            new_prefix = f"down_blocks.{i}.downsamplers.0"
            old_prefix = f"input_blocks.{current_layer}.0"
            new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
            current_layer += 1

        prev_channels = current_channels

    # hardcoded the mid-block for now
    new_prefix = "mid_block.resnets.0"
    old_prefix = "middle_block.0"
    new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
    new_prefix = "mid_block.attentions.0"
    old_prefix = "middle_block.1"
    new_checkpoint = convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim)
    new_prefix = "mid_block.resnets.1"
    old_prefix = "middle_block.2"
    new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)

    current_layer = 0
    up_block_types = unet_config["up_block_types"]

    for i, layer_type in enumerate(up_block_types):
        if layer_type == "ResnetUpsampleBlock2D":
            for j in range(layers_per_block + 1):
                new_prefix = f"up_blocks.{i}.resnets.{j}"
                old_prefix = f"output_blocks.{current_layer}.0"
                new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=True)
                current_layer += 1
            if i != len(up_block_types) - 1:
                new_prefix = f"up_blocks.{i}.upsamplers.0"
                old_prefix = f"output_blocks.{current_layer-1}.1"
                new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
        elif layer_type == "AttnUpBlock2D":
            for j in range(layers_per_block + 1):
                new_prefix = f"up_blocks.{i}.resnets.{j}"
                old_prefix = f"output_blocks.{current_layer}.0"
                new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=True)
                new_prefix = f"up_blocks.{i}.attentions.{j}"
                old_prefix = f"output_blocks.{current_layer}.1"
                new_checkpoint = convert_attention(
                    checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim
                )
                current_layer += 1
            if i != len(up_block_types) - 1:
                new_prefix = f"up_blocks.{i}.upsamplers.0"
                old_prefix = f"output_blocks.{current_layer-1}.2"
                new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)

    new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"]
    new_checkpoint["conv_norm_out.bias"] = checkpoint["out.0.bias"]
    new_checkpoint["conv_out.weight"] = checkpoint["out.2.weight"]
    new_checkpoint["conv_out.bias"] = checkpoint["out.2.bias"]

    return new_checkpoint
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--unet_path", default=None, type=str, required=True, help="Path to the unet.pt to convert.")
    parser.add_argument(
        "--dump_path", default=None, type=str, required=True, help="Path to output the converted UNet model."
    )
    parser.add_argument("--class_cond", default=True, type=str, help="Whether the model is class-conditional.")
    args = parser.parse_args()
    args.class_cond = str2bool(args.class_cond)

    ckpt_name = os.path.basename(args.unet_path)
    print(f"Checkpoint: {ckpt_name}")

    # Get U-Net config
    if "imagenet64" in ckpt_name:
        unet_config = IMAGENET_64_UNET_CONFIG
    elif "256" in ckpt_name and (("bedroom" in ckpt_name) or ("cat" in ckpt_name)):
        unet_config = LSUN_256_UNET_CONFIG
    elif "test" in ckpt_name:
        unet_config = TEST_UNET_CONFIG
    else:
        raise ValueError(f"Checkpoint type {ckpt_name} is not currently supported.")
    if not args.class_cond:
        unet_config["num_class_embeds"] = None

    converted_unet_ckpt = con_pt_to_diffuser(args.unet_path, unet_config)

    image_unet = UNet2DModel(**unet_config)
    image_unet.load_state_dict(converted_unet_ckpt)

    # Get scheduler config
    if "cd" in ckpt_name or "test" in ckpt_name:
        scheduler_config = CD_SCHEDULER_CONFIG
    elif "ct" in ckpt_name and "imagenet64" in ckpt_name:
        scheduler_config = CT_IMAGENET_64_SCHEDULER_CONFIG
    elif "ct" in ckpt_name and "256" in ckpt_name and (("bedroom" in ckpt_name) or ("cat" in ckpt_name)):
        scheduler_config = CT_LSUN_256_SCHEDULER_CONFIG
    else:
        raise ValueError(f"Checkpoint type {ckpt_name} is not currently supported.")

    cm_scheduler = CMStochasticIterativeScheduler(**scheduler_config)

    consistency_model = ConsistencyModelPipeline(unet=image_unet, scheduler=cm_scheduler)
    consistency_model.save_pretrained(args.dump_path)
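The script keys both the UNet and scheduler configs off substrings of the checkpoint filename (e.g. `cd`/`ct`, `imagenet64`, `bedroom`, `cat`, `256`). As a sketch, after running it with, say, `--unet_path cd_imagenet64_l2.pt --dump_path ./cd_imagenet64_l2` (filenames here are illustrative), the saved pipeline can be reloaded directly:
```python
from diffusers import ConsistencyModelPipeline

# Reload the pipeline written by save_pretrained() above; the path is whatever
# was passed as --dump_path (illustrative here, not from the original docs).
pipe = ConsistencyModelPipeline.from_pretrained("./cd_imagenet64_l2")
image = pipe(num_inference_steps=1).images[0]
```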


@@ -58,6 +58,7 @@ else:
)
from .pipelines import (
AudioPipelineOutput,
ConsistencyModelPipeline,
DanceDiffusionPipeline,
DDIMPipeline,
DDPMPipeline,
@@ -72,6 +73,7 @@ else:
ScoreSdeVePipeline,
)
from .schedulers import (
CMStochasticIterativeScheduler,
DDIMInverseScheduler,
DDIMParallelScheduler,
DDIMScheduler,


@@ -66,6 +66,10 @@ class UNet2DModel(ModelMixin, ConfigMixin):
layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
downsample_type (`str`, *optional*, defaults to `conv`):
    The downsample type for downsampling layers. Choose between `"conv"` and `"resnet"`.
upsample_type (`str`, *optional*, defaults to `conv`):
    The upsample type for upsampling layers. Choose between `"conv"` and `"resnet"`.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
@@ -96,6 +100,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
layers_per_block: int = 2,
mid_block_scale_factor: float = 1,
downsample_padding: int = 1,
downsample_type: str = "conv",
upsample_type: str = "conv",
act_fn: str = "silu",
attention_head_dim: Optional[int] = 8,
norm_num_groups: int = 32,
@@ -168,6 +174,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
downsample_padding=downsample_padding,
resnet_time_scale_shift=resnet_time_scale_shift,
downsample_type=downsample_type,
)
self.down_blocks.append(down_block)
@@ -207,6 +214,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
resnet_groups=norm_num_groups,
attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
resnet_time_scale_shift=resnet_time_scale_shift,
upsample_type=upsample_type,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
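These new `downsample_type`/`upsample_type` options are exercised by the conversion script above; as a sketch, the small test configuration (`TEST_UNET_CONFIG`) builds directly:
```python
from diffusers import UNet2DModel

# Instantiate a UNet2DModel with resnet-style up/downsampling, mirroring
# TEST_UNET_CONFIG from the conversion script in this commit.
unet = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    num_class_embeds=1000,
    block_out_channels=[32, 64],
    attention_head_dim=8,
    down_block_types=["ResnetDownsampleBlock2D", "AttnDownBlock2D"],
    up_block_types=["AttnUpBlock2D", "ResnetUpsampleBlock2D"],
    resnet_time_scale_shift="scale_shift",
    upsample_type="resnet",
    downsample_type="resnet",
)
```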


@@ -51,6 +51,7 @@ def get_down_block(
resnet_out_scale_factor=1.0,
cross_attention_norm=None,
attention_head_dim=None,
downsample_type=None,
):
# If attn head dim is not defined, we default it to the number of heads
if attention_head_dim is None:
@@ -88,18 +89,22 @@ def get_down_block(
output_scale_factor=resnet_out_scale_factor,
)
elif down_block_type == "AttnDownBlock2D":
if add_downsample is False:
downsample_type = None
else:
downsample_type = downsample_type or "conv" # default to 'conv'
return AttnDownBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
downsample_type=downsample_type,
)
elif down_block_type == "CrossAttnDownBlock2D":
if cross_attention_dim is None:
@@ -239,6 +244,7 @@ def get_up_block(
resnet_out_scale_factor=1.0,
cross_attention_norm=None,
attention_head_dim=None,
upsample_type=None,
):
# If attn head dim is not defined, we default it to the number of heads
if attention_head_dim is None:
@@ -319,18 +325,23 @@ def get_up_block(
cross_attention_norm=cross_attention_norm,
)
elif up_block_type == "AttnUpBlock2D":
if add_upsample is False:
upsample_type = None
else:
upsample_type = upsample_type or "conv" # default to 'conv'
return AttnUpBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
upsample_type=upsample_type,
)
elif up_block_type == "SkipUpBlock2D":
return SkipUpBlock2D(
@@ -747,11 +758,12 @@ class AttnDownBlock2D(nn.Module):
attention_head_dim=1,
output_scale_factor=1.0,
downsample_padding=1,
add_downsample=True,
downsample_type="conv",
):
super().__init__()
resnets = []
attentions = []
self.downsample_type = downsample_type
if attention_head_dim is None:
logger.warn(
@@ -793,7 +805,7 @@ class AttnDownBlock2D(nn.Module):
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
if add_downsample:
if downsample_type == "conv":
self.downsamplers = nn.ModuleList(
[
Downsample2D(
@@ -801,6 +813,24 @@ class AttnDownBlock2D(nn.Module):
)
]
)
elif downsample_type == "resnet":
self.downsamplers = nn.ModuleList(
[
ResnetBlock2D(
in_channels=out_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
down=True,
)
]
)
else:
self.downsamplers = None
@@ -810,11 +840,14 @@ class AttnDownBlock2D(nn.Module):
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states)
output_states += (hidden_states,)
output_states = output_states + (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
if self.downsample_type == "resnet":
hidden_states = downsampler(hidden_states, temb=temb)
else:
hidden_states = downsampler(hidden_states)
output_states += (hidden_states,)
@@ -1860,12 +1893,14 @@ class AttnUpBlock2D(nn.Module):
resnet_pre_norm: bool = True,
attention_head_dim=1,
output_scale_factor=1.0,
add_upsample=True,
upsample_type="conv",
):
super().__init__()
resnets = []
attentions = []
self.upsample_type = upsample_type
if attention_head_dim is None:
logger.warn(
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
@@ -1908,8 +1943,26 @@ class AttnUpBlock2D(nn.Module):
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
if add_upsample:
if upsample_type == "conv":
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
elif upsample_type == "resnet":
self.upsamplers = nn.ModuleList(
[
ResnetBlock2D(
in_channels=out_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
up=True,
)
]
)
else:
self.upsamplers = None
@@ -1925,7 +1978,10 @@ class AttnUpBlock2D(nn.Module):
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
if self.upsample_type == "resnet":
hidden_states = upsampler(hidden_states, temb=temb)
else:
hidden_states = upsampler(hidden_states)
return hidden_states
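Unlike the convolutional samplers, the resnet-style down/upsamplers are time-conditioned, which is why the forward passes above branch on `downsample_type`/`upsample_type` and pass `temb` through to the sampler.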


@@ -16,6 +16,7 @@ try:
except OptionalDependencyNotAvailable:
from ..utils.dummy_pt_objects import * # noqa F403
else:
from .consistency_models import ConsistencyModelPipeline
from .dance_diffusion import DanceDiffusionPipeline
from .ddim import DDIMPipeline
from .ddpm import DDPMPipeline


@@ -0,0 +1 @@
from .pipeline_consistency_models import ConsistencyModelPipeline


@@ -0,0 +1,337 @@
from typing import Callable, List, Optional, Union
import torch
from ...models import UNet2DModel
from ...schedulers import CMStochasticIterativeScheduler
from ...utils import (
is_accelerate_available,
is_accelerate_version,
logging,
randn_tensor,
replace_example_docstring,
)
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import ConsistencyModelPipeline
>>> device = "cuda"
>>> # Load the cd_imagenet64_l2 checkpoint.
>>> model_id_or_path = "openai/diffusers-cd_imagenet64_l2"
>>> pipe = ConsistencyModelPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
>>> pipe.to(device)
>>> # Onestep Sampling
>>> image = pipe(num_inference_steps=1).images[0]
>>> image.save("cd_imagenet64_l2_onestep_sample.png")
>>> # Onestep sampling, class-conditional image generation
>>> # ImageNet-64 class label 145 corresponds to king penguins
>>> image = pipe(num_inference_steps=1, class_labels=145).images[0]
>>> image.save("cd_imagenet64_l2_onestep_sample_penguin.png")
>>> # Multistep sampling, class-conditional image generation
>>> # Timesteps can be explicitly specified; the particular timesteps below are from the original Github repo:
>>> # https://github.com/openai/consistency_models/blob/main/scripts/launch.sh#L77
>>> image = pipe(num_inference_steps=None, timesteps=[22, 0], class_labels=145).images[0]
>>> image.save("cd_imagenet64_l2_multistep_sample_penguin.png")
```
"""
class ConsistencyModelPipeline(DiffusionPipeline):
r"""
Pipeline for consistency models for unconditional or class-conditional image generation, as introduced in [1].
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
[1] Song, Yang and Dhariwal, Prafulla and Chen, Mark and Sutskever, Ilya. "Consistency Models"
https://arxiv.org/pdf/2303.01469
Args:
unet ([`UNet2DModel`]):
Unconditional or class-conditional U-Net architecture to denoise image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the image latents. Currently only compatible
with [`CMStochasticIterativeScheduler`].
"""
def __init__(self, unet: UNet2DModel, scheduler: CMStochasticIterativeScheduler) -> None:
super().__init__()
self.register_modules(
unet=unet,
scheduler=scheduler,
)
self.safety_checker = None
    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the unet (and
        the safety checker, if set) has its state dict saved to CPU and is then moved to `torch.device('meta')`,
        being loaded to GPU only when its specific submodule has its `forward` method called.

        Note that offloading happens on a submodule basis. Memory savings are higher than with
        `enable_model_cpu_offload`, but performance is lower.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
            from accelerate import cpu_offload
        else:
            raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")

        device = torch.device(f"cuda:{gpu_id}")

        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        for cpu_offloaded_model in [self.unet]:
            cpu_offload(cpu_offloaded_model, device)

        if self.safety_checker is not None:
            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
    def enable_model_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

        device = torch.device(f"cuda:{gpu_id}")

        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        hook = None
        for cpu_offloaded_model in [self.unet]:
            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

        if self.safety_checker is not None:
            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)

        # We'll offload the last model manually.
        self.final_offload_hook = hook
    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device
    def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None):
        shape = (batch_size, num_channels, height, width)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device=device, dtype=dtype)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents
    # Follows diffusers.VaeImageProcessor.postprocess
    def postprocess_image(self, sample: torch.FloatTensor, output_type: str = "pil"):
        if output_type not in ["pt", "np", "pil"]:
            raise ValueError(
                f"output_type={output_type} is not supported. Make sure to choose one of 'pt', 'np', or 'pil'."
            )

        # Equivalent to diffusers.VaeImageProcessor.denormalize
        sample = (sample / 2 + 0.5).clamp(0, 1)
        if output_type == "pt":
            return sample

        # Equivalent to diffusers.VaeImageProcessor.pt_to_numpy
        sample = sample.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "np":
            return sample

        # Output_type must be 'pil'
        sample = self.numpy_to_pil(sample)
        return sample
    def prepare_class_labels(self, batch_size, device, class_labels=None):
        if self.unet.config.num_class_embeds is not None:
            if isinstance(class_labels, list):
                class_labels = torch.tensor(class_labels, dtype=torch.int)
            elif isinstance(class_labels, int):
                assert batch_size == 1, "Batch size must be 1 if `class_labels` is an int"
                class_labels = torch.tensor([class_labels], dtype=torch.int)
            elif class_labels is None:
                # Randomly generate batch_size class labels
                # TODO: should use generator here? int analogue of randn_tensor is not exposed in ...utils
                class_labels = torch.randint(0, self.unet.config.num_class_embeds, size=(batch_size,))
            class_labels = class_labels.to(device)
        else:
            class_labels = None

        return class_labels
    def check_inputs(self, num_inference_steps, timesteps, latents, batch_size, img_size, callback_steps):
        if num_inference_steps is None and timesteps is None:
            raise ValueError("Exactly one of `num_inference_steps` or `timesteps` must be supplied.")

        if num_inference_steps is not None and timesteps is not None:
            logger.warning(
                f"Both `num_inference_steps`: {num_inference_steps} and `timesteps`: {timesteps} are supplied;"
                " `timesteps` will be used over `num_inference_steps`."
            )

        if latents is not None:
            expected_shape = (batch_size, 3, img_size, img_size)
            if latents.shape != expected_shape:
                raise ValueError(f"The shape of latents is {latents.shape} but is expected to be {expected_shape}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        batch_size: int = 1,
        class_labels: Optional[Union[torch.Tensor, List[int], int]] = None,
        num_inference_steps: int = 1,
        timesteps: List[int] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
    ):
r"""
Args:
batch_size (`int`, *optional*, defaults to 1):
The number of images to generate.
class_labels (`torch.Tensor` or `List[int]` or `int`, *optional*):
Optional class labels for conditioning class-conditional consistency models. Will not be used if the
model is not class-conditional.
num_inference_steps (`int`, *optional*, defaults to 1):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
timesteps are used. Must be in descending order.
generator (`torch.Generator`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Examples:
Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if `return_dict` is
True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images.
"""
        # 0. Prepare call parameters
        img_size = self.unet.config.sample_size
        device = self._execution_device

        # 1. Check inputs
        self.check_inputs(num_inference_steps, timesteps, latents, batch_size, img_size, callback_steps)

        # 2. Prepare image latents
        # Sample image latents x_0 ~ N(0, sigma_0^2 * I)
        sample = self.prepare_latents(
            batch_size=batch_size,
            num_channels=self.unet.config.in_channels,
            height=img_size,
            width=img_size,
            dtype=self.unet.dtype,
            device=device,
            generator=generator,
            latents=latents,
        )

        # 3. Handle class_labels for class-conditional models
        class_labels = self.prepare_class_labels(batch_size, device, class_labels=class_labels)

        # 4. Prepare timesteps
        if timesteps is not None:
            self.scheduler.set_timesteps(timesteps=timesteps, device=device)
            timesteps = self.scheduler.timesteps
            num_inference_steps = len(timesteps)
        else:
            self.scheduler.set_timesteps(num_inference_steps)
            timesteps = self.scheduler.timesteps

        # 5. Denoising loop
        # Multistep sampling: implements Algorithm 1 in the paper
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                scaled_sample = self.scheduler.scale_model_input(sample, t)
                model_output = self.unet(scaled_sample, t, class_labels=class_labels, return_dict=False)[0]

                sample = self.scheduler.step(model_output, t, sample, generator=generator)[0]

                # call the callback, if provided
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, sample)

        # 6. Post-process image sample
        image = self.postprocess_image(sample, output_type=output_type)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)


@@ -28,6 +28,7 @@ try:
except OptionalDependencyNotAvailable:
from ..utils.dummy_pt_objects import * # noqa F403
else:
from .scheduling_consistency_models import CMStochasticIterativeScheduler
from .scheduling_ddim import DDIMScheduler
from .scheduling_ddim_inverse import DDIMInverseScheduler
from .scheduling_ddim_parallel import DDIMParallelScheduler


@@ -0,0 +1,380 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, logging, randn_tensor
from .scheduling_utils import SchedulerMixin
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class CMStochasticIterativeSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's step function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
"""
prev_sample: torch.FloatTensor
class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
"""
Multistep and onestep sampling for consistency models from Song et al. 2023 [1]. This implements Algorithm 1 in the
paper [1].
[1] Song, Yang and Dhariwal, Prafulla and Chen, Mark and Sutskever, Ilya. "Consistency Models"
https://arxiv.org/pdf/2303.01469 [2] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based
Generative Models." https://arxiv.org/abs/2206.00364
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
sigma_min (`float`):
Minimum noise magnitude in the sigma schedule. This was set to 0.002 in the original implementation.
sigma_max (`float`):
Maximum noise magnitude in the sigma schedule. This was set to 80.0 in the original implementation.
sigma_data (`float`):
The standard deviation of the data distribution, following the EDM paper [2]. This was set to 0.5 in the
original implementation, which is also the original value suggested in the EDM paper.
s_noise (`float`):
The amount of additional noise to counteract loss of detail during sampling. A reasonable range is [1.000,
1.011]. This was set to 1.0 in the original implementation.
rho (`float`):
The rho parameter used for calculating the Karras sigma schedule, introduced in the EDM paper [2]. This was
set to 7.0 in the original implementation, which is also the original value suggested in the EDM paper.
clip_denoised (`bool`):
Whether to clip the denoised outputs to `(-1, 1)`. Defaults to `True`.
timesteps (`List` or `np.ndarray` or `torch.Tensor`, *optional*):
Optionally, an explicit timestep schedule can be specified. The timesteps are expected to be in increasing
order.
"""
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 40,
        sigma_min: float = 0.002,
        sigma_max: float = 80.0,
        sigma_data: float = 0.5,
        s_noise: float = 1.0,
        rho: float = 7.0,
        clip_denoised: bool = True,
    ):
        # standard deviation of the initial noise distribution
        self.init_noise_sigma = sigma_max

        ramp = np.linspace(0, 1, num_train_timesteps)
        sigmas = self._convert_to_karras(ramp)
        timesteps = self.sigma_to_t(sigmas)

        # setable values
        self.num_inference_steps = None
        self.sigmas = torch.from_numpy(sigmas)
        self.timesteps = torch.from_numpy(timesteps)
        self.custom_timesteps = False
        self.is_scale_input_called = False
    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()
        return indices.item()
    def scale_model_input(
        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
    ) -> torch.FloatTensor:
        """
        Scales the consistency model input by `(sigma**2 + sigma_data**2) ** 0.5`, following the EDM model.

        Args:
            sample (`torch.FloatTensor`): input sample
            timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain

        Returns:
            `torch.FloatTensor`: scaled input sample
        """
        # Get sigma corresponding to timestep
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)
        step_idx = self.index_for_timestep(timestep)
        sigma = self.sigmas[step_idx]

        sample = sample / ((sigma**2 + self.config.sigma_data**2) ** 0.5)

        self.is_scale_input_called = True
        return sample
    def sigma_to_t(self, sigmas: Union[float, np.ndarray]):
        """
        Gets scaled timesteps from the Karras sigmas, for input to the consistency model.

        Args:
            sigmas (`float` or `np.ndarray`): single Karras sigma or array of Karras sigmas

        Returns:
            `float` or `np.ndarray`: scaled input timestep or scaled input timestep array
        """
        if not isinstance(sigmas, np.ndarray):
            sigmas = np.array(sigmas, dtype=np.float64)
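        # Matches the original implementation: t = 250 * ln(sigma + 1e-44); the 1e-44
        # offset guards against taking log(0) when sigma is exactly zero.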
        timesteps = 1000 * 0.25 * np.log(sigmas + 1e-44)

        return timesteps
    def set_timesteps(
        self,
        num_inference_steps: Optional[int] = None,
        device: Union[str, torch.device] = None,
        timesteps: Optional[List[int]] = None,
    ):
        """
        Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            timesteps (`List[int]`, *optional*):
                custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
                timestep spacing strategy of equal spacing between timesteps is used. If passed,
                `num_inference_steps` must be `None`.
        """
        if num_inference_steps is None and timesteps is None:
            raise ValueError("Exactly one of `num_inference_steps` or `timesteps` must be supplied.")

        if num_inference_steps is not None and timesteps is not None:
            raise ValueError("Can only pass one of `num_inference_steps` or `timesteps`.")

        # Follow DDPMScheduler custom timesteps logic
        if timesteps is not None:
            for i in range(1, len(timesteps)):
                if timesteps[i] >= timesteps[i - 1]:
                    raise ValueError("`timesteps` must be in descending order.")

            if timesteps[0] >= self.config.num_train_timesteps:
                raise ValueError(
                    f"`timesteps` must start before `self.config.num_train_timesteps`:"
                    f" {self.config.num_train_timesteps}."
                )

            timesteps = np.array(timesteps, dtype=np.int64)
            self.custom_timesteps = True
        else:
            if num_inference_steps > self.config.num_train_timesteps:
                raise ValueError(
                    f"`num_inference_steps`: {num_inference_steps} cannot be larger than"
                    f" `self.config.num_train_timesteps`: {self.config.num_train_timesteps} as the unet model trained"
                    f" with this scheduler can only handle maximal {self.config.num_train_timesteps} timesteps."
                )

            self.num_inference_steps = num_inference_steps

            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
            self.custom_timesteps = False

        # Map timesteps to Karras sigmas directly for multistep sampling
        # See https://github.com/openai/consistency_models/blob/main/cm/karras_diffusion.py#L675
        num_train_timesteps = self.config.num_train_timesteps
        ramp = timesteps[::-1].copy()
        ramp = ramp / (num_train_timesteps - 1)
        sigmas = self._convert_to_karras(ramp)
        timesteps = self.sigma_to_t(sigmas)

        sigmas = np.concatenate([sigmas, [self.config.sigma_min]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas).to(device=device)

        if str(device).startswith("mps"):
            # mps does not support float64
            self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32)
        else:
            self.timesteps = torch.from_numpy(timesteps).to(device=device)
    # Modified _convert_to_karras implementation that takes in ramp as argument
    def _convert_to_karras(self, ramp):
        """Constructs the noise schedule of Karras et al. (2022)."""

        sigma_min: float = self.config.sigma_min
        sigma_max: float = self.config.sigma_max

        rho = self.config.rho
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
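        # Karras et al. (2022) schedule: sigma_i = (sigma_max^(1/rho) + ramp_i * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho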
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas
    def get_scalings(self, sigma):
        sigma_data = self.config.sigma_data

        c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
        c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
        return c_skip, c_out
    def get_scalings_for_boundary_condition(self, sigma):
        """
        Gets the scalings used in the consistency model parameterization, following Appendix C of the original paper.
        This enforces the consistency model boundary condition.

        Note that `epsilon` in the equations for c_skip and c_out is set to sigma_min.

        Args:
            sigma (`torch.FloatTensor`):
                The current sigma in the Karras sigma schedule.

        Returns:
            `tuple`:
                A two-element tuple where c_skip (which weights the current sample) is the first element and c_out
                (which weights the consistency model output) is the second element.
        """
        sigma_min = self.config.sigma_min
        sigma_data = self.config.sigma_data

        c_skip = sigma_data**2 / ((sigma - sigma_min) ** 2 + sigma_data**2)
        c_out = (sigma - sigma_min) * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
        return c_skip, c_out
    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[CMStochasticIterativeSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timestep (`float`): current timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            generator (`torch.Generator`, *optional*): Random number generator.
            return_dict (`bool`): option for returning tuple rather than CMStochasticIterativeSchedulerOutput class

        Returns:
            [`~schedulers.scheduling_utils.CMStochasticIterativeSchedulerOutput`] or `tuple`:
            [`~schedulers.scheduling_utils.CMStochasticIterativeSchedulerOutput`] if `return_dict` is True, otherwise
            a `tuple`. When returning a tuple, the first element is the sample tensor.
        """
        if (
            isinstance(timestep, int)
            or isinstance(timestep, torch.IntTensor)
            or isinstance(timestep, torch.LongTensor)
        ):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    f" `{self.__class__}.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if not self.is_scale_input_called:
            logger.warning(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)

        sigma_min = self.config.sigma_min
        sigma_max = self.config.sigma_max

        step_index = self.index_for_timestep(timestep)

        # sigma_next corresponds to next_t in original implementation
        sigma = self.sigmas[step_index]
        if step_index + 1 < self.config.num_train_timesteps:
            sigma_next = self.sigmas[step_index + 1]
        else:
            # Set sigma_next to sigma_min
            sigma_next = self.sigmas[-1]

        # Get scalings for boundary conditions
        c_skip, c_out = self.get_scalings_for_boundary_condition(sigma)

        # 1. Denoise model output using boundary conditions
        denoised = c_out * model_output + c_skip * sample
        if self.config.clip_denoised:
            denoised = denoised.clamp(-1, 1)

        # 2. Sample z ~ N(0, s_noise^2 * I)
        # Noise is not used for onestep sampling.
        if len(self.timesteps) > 1:
            noise = randn_tensor(
                model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
            )
        else:
            noise = torch.zeros_like(model_output)
        z = noise * self.config.s_noise

        sigma_hat = sigma_next.clamp(min=sigma_min, max=sigma_max)

        # 3. Return noisy sample
        # tau = sigma_hat, eps = sigma_min
        prev_sample = denoised + z * (sigma_hat**2 - sigma_min**2) ** 0.5

        if not return_dict:
            return (prev_sample,)

        return CMStochasticIterativeSchedulerOutput(prev_sample=prev_sample)
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.FloatTensor,
) -> torch.FloatTensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
# mps does not support float64
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
else:
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)
noisy_samples = original_samples + noise * sigma
return noisy_samples
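The per-timestep indexing above lets each batch element carry its own noise level; `sigma` is unsqueezed until it broadcasts against the sample shape. A minimal sketch, assuming a built `scheduler`:

    # Minimal sketch: per-sample sigmas broadcast against a (B, C, H, W) batch
    original = torch.randn(2, 3, 64, 64)
    noise = torch.randn_like(original)
    scheduler.set_timesteps(40)
    t_batch = scheduler.timesteps[[0, 10]]  # one schedule timestep per batch element
    noisy = scheduler.add_noise(original, noise, t_batch)  # sigma is broadcast to (2, 1, 1, 1)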
def __len__(self):
return self.config.num_train_timesteps
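At the pipeline level the scheduler is wired up as in the slow tests below; a minimal end-to-end sketch reusing their hub names and scheduler config:

    # Minimal end-to-end sketch, mirroring the slow tests' checkpoint names below
    unet = UNet2DModel.from_pretrained("diffusers/consistency_models", subfolder="diffusers_cd_imagenet64_l2")
    scheduler = CMStochasticIterativeScheduler(num_train_timesteps=40, sigma_min=0.002, sigma_max=80.0)
    pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler).to("cuda")
    image = pipe(num_inference_steps=None, timesteps=[22, 0], output_type="np").images  # multistep sampling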


@@ -210,6 +210,21 @@ class AudioPipelineOutput(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class ConsistencyModelPipeline(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class DanceDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch"]
@@ -390,6 +405,21 @@ class ScoreSdeVePipeline(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class CMStochasticIterativeScheduler(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class DDIMInverseScheduler(metaclass=DummyObject):
_backends = ["torch"]


@@ -0,0 +1,288 @@
import gc
import unittest
import numpy as np
import torch
from torch.backends.cuda import sdp_kernel
from diffusers import (
CMStochasticIterativeScheduler,
ConsistencyModelPipeline,
UNet2DModel,
)
from diffusers.utils import randn_tensor, slow, torch_device
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_2, require_torch_gpu
from ..pipeline_params import UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS, UNCONDITIONAL_IMAGE_GENERATION_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
class ConsistencyModelPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = ConsistencyModelPipeline
params = UNCONDITIONAL_IMAGE_GENERATION_PARAMS
batch_params = UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS
# Override required_optional_params to remove num_images_per_prompt
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"latents",
"output_type",
"return_dict",
"callback",
"callback_steps",
]
)
@property
def dummy_uncond_unet(self):
unet = UNet2DModel.from_pretrained(
"diffusers/consistency-models-test",
subfolder="test_unet",
)
return unet
@property
def dummy_cond_unet(self):
unet = UNet2DModel.from_pretrained(
"diffusers/consistency-models-test",
subfolder="test_unet_class_cond",
)
return unet
def get_dummy_components(self, class_cond=False):
if class_cond:
unet = self.dummy_cond_unet
else:
unet = self.dummy_uncond_unet
# Default to CM multistep sampler
scheduler = CMStochasticIterativeScheduler(
num_train_timesteps=40,
sigma_min=0.002,
sigma_max=80.0,
)
components = {
"unet": unet,
"scheduler": scheduler,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"batch_size": 1,
"num_inference_steps": None,
"timesteps": [22, 0],
"generator": generator,
"output_type": "np",
}
return inputs
def test_consistency_model_pipeline_multistep(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
pipe = ConsistencyModelPipeline(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
assert image.shape == (1, 32, 32, 3)
image_slice = image[0, -3:, -3:, -1]
expected_slice = np.array([0.3572, 0.6273, 0.4031, 0.3961, 0.4321, 0.5730, 0.5266, 0.4780, 0.5004])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
def test_consistency_model_pipeline_multistep_class_cond(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components(class_cond=True)
pipe = ConsistencyModelPipeline(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["class_labels"] = 0
image = pipe(**inputs).images
assert image.shape == (1, 32, 32, 3)
image_slice = image[0, -3:, -3:, -1]
expected_slice = np.array([0.3572, 0.6273, 0.4031, 0.3961, 0.4321, 0.5730, 0.5266, 0.4780, 0.5004])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
def test_consistency_model_pipeline_onestep(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
pipe = ConsistencyModelPipeline(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["num_inference_steps"] = 1
inputs["timesteps"] = None
image = pipe(**inputs).images
assert image.shape == (1, 32, 32, 3)
image_slice = image[0, -3:, -3:, -1]
expected_slice = np.array([0.5004, 0.5004, 0.4994, 0.5008, 0.4976, 0.5018, 0.4990, 0.4982, 0.4987])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
def test_consistency_model_pipeline_onestep_class_cond(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components(class_cond=True)
pipe = ConsistencyModelPipeline(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["num_inference_steps"] = 1
inputs["timesteps"] = None
inputs["class_labels"] = 0
image = pipe(**inputs).images
assert image.shape == (1, 32, 32, 3)
image_slice = image[0, -3:, -3:, -1]
expected_slice = np.array([0.5004, 0.5004, 0.4994, 0.5008, 0.4976, 0.5018, 0.4990, 0.4982, 0.4987])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
@slow
@require_torch_gpu
class ConsistencyModelPipelineSlowTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self, seed=0, get_fixed_latents=False, device="cpu", dtype=torch.float32, shape=(1, 3, 64, 64)):
generator = torch.manual_seed(seed)
inputs = {
"num_inference_steps": None,
"timesteps": [22, 0],
"class_labels": 0,
"generator": generator,
"output_type": "np",
}
if get_fixed_latents:
latents = self.get_fixed_latents(seed=seed, device=device, dtype=dtype, shape=shape)
inputs["latents"] = latents
return inputs
def get_fixed_latents(self, seed=0, device="cpu", dtype=torch.float32, shape=(1, 3, 64, 64)):
if isinstance(device, str):
device = torch.device(device)
generator = torch.Generator(device=device).manual_seed(seed)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
return latents
def test_consistency_model_cd_multistep(self):
unet = UNet2DModel.from_pretrained("diffusers/consistency_models", subfolder="diffusers_cd_imagenet64_l2")
scheduler = CMStochasticIterativeScheduler(
num_train_timesteps=40,
sigma_min=0.002,
sigma_max=80.0,
)
pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler)
pipe.to(torch_device=torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs()
image = pipe(**inputs).images
assert image.shape == (1, 64, 64, 3)
image_slice = image[0, -3:, -3:, -1]
expected_slice = np.array([0.0888, 0.0881, 0.0666, 0.0479, 0.0292, 0.0195, 0.0201, 0.0163, 0.0254])
assert np.abs(image_slice.flatten() - expected_slice).max() < 2e-2
def test_consistency_model_cd_onestep(self):
unet = UNet2DModel.from_pretrained("diffusers/consistency_models", subfolder="diffusers_cd_imagenet64_l2")
scheduler = CMStochasticIterativeScheduler(
num_train_timesteps=40,
sigma_min=0.002,
sigma_max=80.0,
)
pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler)
pipe.to(torch_device=torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs()
inputs["num_inference_steps"] = 1
inputs["timesteps"] = None
image = pipe(**inputs).images
assert image.shape == (1, 64, 64, 3)
image_slice = image[0, -3:, -3:, -1]
expected_slice = np.array([0.0340, 0.0152, 0.0063, 0.0267, 0.0221, 0.0107, 0.0416, 0.0186, 0.0217])
assert np.abs(image_slice.flatten() - expected_slice).max() < 2e-2
@require_torch_2
def test_consistency_model_cd_multistep_flash_attn(self):
unet = UNet2DModel.from_pretrained("diffusers/consistency_models", subfolder="diffusers_cd_imagenet64_l2")
scheduler = CMStochasticIterativeScheduler(
num_train_timesteps=40,
sigma_min=0.002,
sigma_max=80.0,
)
pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler)
pipe.to(torch_device=torch_device, torch_dtype=torch.float16)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(get_fixed_latents=True, device=torch_device)
# Ensure usage of flash attention in torch 2.0
with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
image = pipe(**inputs).images
assert image.shape == (1, 64, 64, 3)
image_slice = image[0, -3:, -3:, -1]
expected_slice = np.array([0.1875, 0.1428, 0.1289, 0.2151, 0.2092, 0.1477, 0.1877, 0.1641, 0.1353])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
@require_torch_2
def test_consistency_model_cd_onestep_flash_attn(self):
unet = UNet2DModel.from_pretrained("diffusers/consistency_models", subfolder="diffusers_cd_imagenet64_l2")
scheduler = CMStochasticIterativeScheduler(
num_train_timesteps=40,
sigma_min=0.002,
sigma_max=80.0,
)
pipe = ConsistencyModelPipeline(unet=unet, scheduler=scheduler)
pipe.to(torch_device=torch_device, torch_dtype=torch.float16)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_inputs(get_fixed_latents=True, device=torch_device)
inputs["num_inference_steps"] = 1
inputs["timesteps"] = None
# Ensure usage of flash attention in torch 2.0
with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
image = pipe(**inputs).images
assert image.shape == (1, 64, 64, 3)
image_slice = image[0, -3:, -3:, -1]
expected_slice = np.array([0.1663, 0.1948, 0.2275, 0.1680, 0.1204, 0.1245, 0.1858, 0.1338, 0.2095])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3


@@ -0,0 +1,150 @@
import torch
from diffusers import CMStochasticIterativeScheduler
from .test_schedulers import SchedulerCommonTest
class CMStochasticIterativeSchedulerTest(SchedulerCommonTest):
scheduler_classes = (CMStochasticIterativeScheduler,)
num_inference_steps = 10
def get_scheduler_config(self, **kwargs):
config = {
"num_train_timesteps": 201,
"sigma_min": 0.002,
"sigma_max": 80.0,
}
config.update(**kwargs)
return config
# Override test_step_shape to add CMStochasticIterativeScheduler-specific logic regarding timesteps
# The problem is that we don't know two timesteps that will always be in the timestep schedule from only the scheduler
# config; scaled sigma_max is always in the timestep schedule, but sigma_min is in the sigma schedule while scaled
# sigma_min is not in the timestep schedule
def test_step_shape(self):
num_inference_steps = 10
scheduler_config = self.get_scheduler_config()
scheduler = self.scheduler_classes[0](**scheduler_config)
scheduler.set_timesteps(num_inference_steps)
timestep_0 = scheduler.timesteps[0]
timestep_1 = scheduler.timesteps[1]
sample = self.dummy_sample
residual = 0.1 * sample
output_0 = scheduler.step(residual, timestep_0, sample).prev_sample
output_1 = scheduler.step(residual, timestep_1, sample).prev_sample
self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape)
def test_timesteps(self):
for timesteps in [10, 50, 100, 1000]:
self.check_over_configs(num_train_timesteps=timesteps)
def test_clip_denoised(self):
for clip_denoised in [True, False]:
self.check_over_configs(clip_denoised=clip_denoised)
def test_full_loop_no_noise_onestep(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
num_inference_steps = 1
scheduler.set_timesteps(num_inference_steps)
timesteps = scheduler.timesteps
generator = torch.manual_seed(0)
model = self.dummy_model()
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
for t in timesteps:
# 1. scale model input
scaled_sample = scheduler.scale_model_input(sample, t)
# 2. predict noise residual
residual = model(scaled_sample, t)
# 3. predict previous sample x_t-1
pred_prev_sample = scheduler.step(residual, t, sample, generator=generator).prev_sample
sample = pred_prev_sample
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 192.7614) < 1e-2
assert abs(result_mean.item() - 0.2510) < 1e-3
def test_full_loop_no_noise_multistep(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
timesteps = [106, 0]
scheduler.set_timesteps(timesteps=timesteps)
timesteps = scheduler.timesteps
generator = torch.manual_seed(0)
model = self.dummy_model()
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
for t in timesteps:
# 1. scale model input
scaled_sample = scheduler.scale_model_input(sample, t)
# 2. predict noise residual
residual = model(scaled_sample, t)
# 3. predict previous sample x_t-1
pred_prev_sample = scheduler.step(residual, t, sample, generator=generator).prev_sample
sample = pred_prev_sample
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 347.6357) < 1e-2
assert abs(result_mean.item() - 0.4527) < 1e-3
def test_custom_timesteps_increasing_order(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
timesteps = [39, 30, 12, 15, 0]
with self.assertRaises(ValueError, msg="`timesteps` must be in descending order."):
scheduler.set_timesteps(timesteps=timesteps)
def test_custom_timesteps_passing_both_num_inference_steps_and_timesteps(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
timesteps = [39, 30, 12, 1, 0]
num_inference_steps = len(timesteps)
with self.assertRaises(ValueError, msg="Can only pass one of `num_inference_steps` or `timesteps`."):
scheduler.set_timesteps(num_inference_steps=num_inference_steps, timesteps=timesteps)
def test_custom_timesteps_too_large(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
timesteps = [scheduler.config.num_train_timesteps]
with self.assertRaises(
ValueError,
msg="`timesteps` must start before `self.config.train_timesteps`: {scheduler.config.num_train_timesteps}}",
):
scheduler.set_timesteps(timesteps=timesteps)


@@ -24,6 +24,7 @@ import torch
import diffusers
from diffusers import (
CMStochasticIterativeScheduler,
DDIMScheduler,
DEISMultistepScheduler,
DiffusionPipeline,
@@ -303,6 +304,11 @@ class SchedulerCommonTest(unittest.TestCase):
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)
if scheduler_class == CMStochasticIterativeScheduler:
# Get valid timestep based on sigma_max, which should always be in timestep schedule.
scaled_sigma_max = scheduler.sigma_to_t(scheduler.config.sigma_max)
time_step = scaled_sigma_max
if scheduler_class == VQDiffusionScheduler:
num_vec_classes = scheduler_config["num_vec_classes"]
sample = self.dummy_sample(num_vec_classes)
@@ -323,7 +329,11 @@ class SchedulerCommonTest(unittest.TestCase):
kwargs["num_inference_steps"] = num_inference_steps
# Make sure `scale_model_input` is invoked to prevent a warning
if scheduler_class == CMStochasticIterativeScheduler:
# Get valid timestep based on sigma_max, which should always be in timestep schedule.
_ = scheduler.scale_model_input(sample, scaled_sigma_max)
_ = new_scheduler.scale_model_input(sample, scaled_sigma_max)
elif scheduler_class != VQDiffusionScheduler:
_ = scheduler.scale_model_input(sample, 0)
_ = new_scheduler.scale_model_input(sample, 0)
@@ -393,6 +403,10 @@ class SchedulerCommonTest(unittest.TestCase):
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
if scheduler_class == CMStochasticIterativeScheduler:
# Get valid timestep based on sigma_max, which should always be in timestep schedule.
timestep = scheduler.sigma_to_t(scheduler.config.sigma_max)
if scheduler_class == VQDiffusionScheduler:
num_vec_classes = scheduler_config["num_vec_classes"]
sample = self.dummy_sample(num_vec_classes)
@@ -539,6 +553,10 @@ class SchedulerCommonTest(unittest.TestCase):
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
if scheduler_class == CMStochasticIterativeScheduler:
# Get valid timestep based on sigma_max, which should always be in timestep schedule.
timestep = scheduler.sigma_to_t(scheduler.config.sigma_max)
if scheduler_class == VQDiffusionScheduler:
num_vec_classes = scheduler_config["num_vec_classes"]
sample = self.dummy_sample(num_vec_classes)
@@ -594,7 +612,12 @@ class SchedulerCommonTest(unittest.TestCase):
if scheduler_class != VQDiffusionScheduler:
sample = self.dummy_sample
if scheduler_class == CMStochasticIterativeScheduler:
# Get valid timestep based on sigma_max, which should always be in timestep schedule.
scaled_sigma_max = scheduler.sigma_to_t(scheduler.config.sigma_max)
scaled_sample = scheduler.scale_model_input(sample, scaled_sigma_max)
else:
scaled_sample = scheduler.scale_model_input(sample, 0.0)
self.assertEqual(sample.shape, scaled_sample.shape)
def test_add_noise_device(self):
@@ -606,7 +629,12 @@ class SchedulerCommonTest(unittest.TestCase):
scheduler.set_timesteps(100)
sample = self.dummy_sample.to(torch_device)
if scheduler_class == CMStochasticIterativeScheduler:
# Get valid timestep based on sigma_max, which should always be in timestep schedule.
scaled_sigma_max = scheduler.sigma_to_t(scheduler.config.sigma_max)
scaled_sample = scheduler.scale_model_input(sample, scaled_sigma_max)
else:
scaled_sample = scheduler.scale_model_input(sample, 0.0)
self.assertEqual(sample.shape, scaled_sample.shape)
noise = torch.randn_like(scaled_sample).to(torch_device)
@@ -637,7 +665,7 @@ class SchedulerCommonTest(unittest.TestCase):
def test_trained_betas(self):
for scheduler_class in self.scheduler_classes:
if scheduler_class in (VQDiffusionScheduler, CMStochasticIterativeScheduler):
continue
scheduler_config = self.get_scheduler_config()