diffusers/scripts/convert_consistency_to_diffusers.py
dg845 aed7499a8d Add Consistency Models Pipeline (#3492)
* initial commit

* Improve consistency models sampling implementation.

* Add CMStochasticIterativeScheduler, which implements the multi-step sampler (stochastic_iterative_sampler) in the original code, and make further improvements to sampling.

* Add Unet blocks for consistency models

* Add conversion script for Unet

* Fix bug in new unet blocks

* Fix attention weight loading

* Make design improvements to ConsistencyModelPipeline and CMStochasticIterativeScheduler and add initial version of tests.

* make style

* Make small random test UNet class conditional and set resnet_time_scale_shift to 'scale_shift' to better match consistency model checkpoints.

* Add support to the consistency models conversion script for converting a test UNet and non-class-conditional UNets.

* make style

* Change num_class_embeds to 1000 to better match the original consistency models implementation.

* Add support for distillation in pipeline_consistency_models.py.

* Improve consistency model tests:
	- Get small testing checkpoints from hub
	- Modify tests to take into account "distillation" parameter of ConsistencyModelPipeline
	- Add onestep, multistep tests for distillation and distillation + class conditional
	- Add expected image slices for onestep tests

* make style

* Improve ConsistencyModelPipeline:
	- Add initial support for class-conditional generation
	- Fix initial sigma for onestep generation
	- Fix some sigma shape issues

* make style

* Improve ConsistencyModelPipeline:
	- add latents __call__ argument and prepare_latents method
	- add check_inputs method
	- add initial docstrings for ConsistencyModelPipeline.__call__

* make style

* Fix bug when randomly generating class labels for class-conditional generation.

* Switch CMStochasticIterativeScheduler to configuring a sigma schedule and make related changes to the pipeline and tests.

* Remove some unused code and make style.

* Fix small bug in CMStochasticIterativeScheduler.

* Add expected slices for multistep sampling tests and make them pass.

* Work on consistency model fast tests:
	- in pipeline, call self.scheduler.scale_model_input before denoising
	- get expected slices for Euler and Heun scheduler tests
	- make Euler test pass
	- mark Heun test as expected fail because it doesn't support prediction_type "sample" yet
	- remove DPM and Euler Ancestral tests because they don't support use_karras_sigmas

* make style

* Refactor conversion script to make it easier to add more model architectures to convert in the future.

* Work on ConsistencyModelPipeline tests:
	- Fix device bug when handling class labels in ConsistencyModelPipeline.__call__
	- Add slow tests for onestep and multistep sampling and make them pass
	- Refactor fast tests
	- Refactor ConsistencyModelPipeline.__init__

* make style

* Remove the add_noise and add_noise_to_input methods from CMStochasticIterativeScheduler for now.

* Run python utils/check_copies.py --fix_and_overwrite and python utils/check_dummies.py --fix_and_overwrite to make dummy objects for the new pipeline and scheduler.

* Make fast tests from PipelineTesterMixin pass.

* make style

* Refactor consistency models pipeline and scheduler:
	- Remove support for Karras schedulers (only support CMStochasticIterativeScheduler)
	- Move sigma manipulation, input scaling, denoising from pipeline to scheduler
	- Make corresponding changes to tests and ensure they pass

* make style

* Add docstrings and further refactor pipeline and scheduler.

* make style

* Add initial version of the consistency models documentation.

* Refactor custom timesteps logic following DDPMScheduler/IFPipeline and temporarily add torch 2.0 SDPA kernel selection logic for debugging.

* make style

* Convert current slow tests to use fp16 and flash attention.

* make style

* Add slow tests for normal attention on cuda device.

* make style

* Fix attention weights loading

* Update consistency model fast tests for new test checkpoints with attention fix.

* make style

* apply suggestions

* Add add_noise method to CMStochasticIterativeScheduler (copied from EulerDiscreteScheduler).

* Conversion script now outputs a pipeline instead of a UNet, and adds support for LSUN-256 models and different schedulers.

* When both timesteps and num_inference_steps are supplied, raise warning instead of error (timesteps take precedence).

* make style

* Add remaining diffusers model checkpoints for models in the original consistency model release and update usage example.

* apply suggestions from review

* make style

* fix attention naming

* Add tests for CMStochasticIterativeScheduler.

* make style

* Make CMStochasticIterativeScheduler tests pass.

* make style

* Override test_step_shape in CMStochasticIterativeSchedulerTest instead of modifying it in SchedulerCommonTest.

* make style

* rename some models

* Improve API

* rename some models

* Remove duplicated block

* Add docstring and make torch compile work

* More fixes

* Fixes

* Apply suggestions from code review

* Apply suggestions from code review

* add more docstring

* update consistency conversion script

---------

Co-authored-by: ayushmangal <ayushmangal@microsoft.com>
Co-authored-by: Ayush Mangal <43698245+ayushtues@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2023-07-05 19:33:58 +02:00
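
The commit log above describes onestep and multistep sampling, class-conditional generation, and custom timesteps. As a rough sketch of how a converted pipeline might be used (the hub repo name, class label, and timestep values below are illustrative, not taken from this commit):

import torch

from diffusers import ConsistencyModelPipeline

# Illustrative checkpoint name; any pipeline produced by the conversion
# script below would load the same way.
pipe = ConsistencyModelPipeline.from_pretrained(
    "openai/diffusers-cd_imagenet64_l2", torch_dtype=torch.float16
).to("cuda")

# Onestep sampling.
image = pipe(num_inference_steps=1).images[0]

# Onestep, class-conditional sampling (illustrative ImageNet class label).
image = pipe(num_inference_steps=1, class_labels=145).images[0]

# Multistep sampling with custom timesteps; per the commit log, explicit
# timesteps take precedence over num_inference_steps.
image = pipe(num_inference_steps=None, timesteps=[22, 0]).images[0]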


import argparse
import os

import torch

from diffusers import (
    CMStochasticIterativeScheduler,
    ConsistencyModelPipeline,
    UNet2DModel,
)

TEST_UNET_CONFIG = {
    "sample_size": 32,
    "in_channels": 3,
    "out_channels": 3,
    "layers_per_block": 2,
    "num_class_embeds": 1000,
    "block_out_channels": [32, 64],
    "attention_head_dim": 8,
    "down_block_types": [
        "ResnetDownsampleBlock2D",
        "AttnDownBlock2D",
    ],
    "up_block_types": [
        "AttnUpBlock2D",
        "ResnetUpsampleBlock2D",
    ],
    "resnet_time_scale_shift": "scale_shift",
    "upsample_type": "resnet",
    "downsample_type": "resnet",
}

IMAGENET_64_UNET_CONFIG = {
    "sample_size": 64,
    "in_channels": 3,
    "out_channels": 3,
    "layers_per_block": 3,
    "num_class_embeds": 1000,
    "block_out_channels": [192, 192 * 2, 192 * 3, 192 * 4],
    "attention_head_dim": 64,
    "down_block_types": [
        "ResnetDownsampleBlock2D",
        "AttnDownBlock2D",
        "AttnDownBlock2D",
        "AttnDownBlock2D",
    ],
    "up_block_types": [
        "AttnUpBlock2D",
        "AttnUpBlock2D",
        "AttnUpBlock2D",
        "ResnetUpsampleBlock2D",
    ],
    "resnet_time_scale_shift": "scale_shift",
    "upsample_type": "resnet",
    "downsample_type": "resnet",
}

LSUN_256_UNET_CONFIG = {
    "sample_size": 256,
    "in_channels": 3,
    "out_channels": 3,
    "layers_per_block": 2,
    "num_class_embeds": None,
    "block_out_channels": [256, 256, 256 * 2, 256 * 2, 256 * 4, 256 * 4],
    "attention_head_dim": 64,
    "down_block_types": [
        "ResnetDownsampleBlock2D",
        "ResnetDownsampleBlock2D",
        "ResnetDownsampleBlock2D",
        "AttnDownBlock2D",
        "AttnDownBlock2D",
        "AttnDownBlock2D",
    ],
    "up_block_types": [
        "AttnUpBlock2D",
        "AttnUpBlock2D",
        "AttnUpBlock2D",
        "ResnetUpsampleBlock2D",
        "ResnetUpsampleBlock2D",
        "ResnetUpsampleBlock2D",
    ],
    "resnet_time_scale_shift": "default",
    "upsample_type": "resnet",
    "downsample_type": "resnet",
}

CD_SCHEDULER_CONFIG = {
    "num_train_timesteps": 40,
    "sigma_min": 0.002,
    "sigma_max": 80.0,
}

CT_IMAGENET_64_SCHEDULER_CONFIG = {
    "num_train_timesteps": 201,
    "sigma_min": 0.002,
    "sigma_max": 80.0,
}

CT_LSUN_256_SCHEDULER_CONFIG = {
    "num_train_timesteps": 151,
    "sigma_min": 0.002,
    "sigma_max": 80.0,
}

def str2bool(v):
    """
    https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("boolean value expected")

def convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=False):
    # Map an original ResBlock (in_layers / emb_layers / out_layers) onto the
    # corresponding diffusers ResnetBlock2D parameters.
    new_checkpoint[f"{new_prefix}.norm1.weight"] = checkpoint[f"{old_prefix}.in_layers.0.weight"]
    new_checkpoint[f"{new_prefix}.norm1.bias"] = checkpoint[f"{old_prefix}.in_layers.0.bias"]
    new_checkpoint[f"{new_prefix}.conv1.weight"] = checkpoint[f"{old_prefix}.in_layers.2.weight"]
    new_checkpoint[f"{new_prefix}.conv1.bias"] = checkpoint[f"{old_prefix}.in_layers.2.bias"]
    new_checkpoint[f"{new_prefix}.time_emb_proj.weight"] = checkpoint[f"{old_prefix}.emb_layers.1.weight"]
    new_checkpoint[f"{new_prefix}.time_emb_proj.bias"] = checkpoint[f"{old_prefix}.emb_layers.1.bias"]
    new_checkpoint[f"{new_prefix}.norm2.weight"] = checkpoint[f"{old_prefix}.out_layers.0.weight"]
    new_checkpoint[f"{new_prefix}.norm2.bias"] = checkpoint[f"{old_prefix}.out_layers.0.bias"]
    new_checkpoint[f"{new_prefix}.conv2.weight"] = checkpoint[f"{old_prefix}.out_layers.3.weight"]
    new_checkpoint[f"{new_prefix}.conv2.bias"] = checkpoint[f"{old_prefix}.out_layers.3.bias"]

    if has_skip:
        new_checkpoint[f"{new_prefix}.conv_shortcut.weight"] = checkpoint[f"{old_prefix}.skip_connection.weight"]
        new_checkpoint[f"{new_prefix}.conv_shortcut.bias"] = checkpoint[f"{old_prefix}.skip_connection.bias"]

    return new_checkpoint

def convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_dim=None):
    # The original checkpoint stores attention as a fused 1x1-conv qkv projection;
    # split it into separate to_q/to_k/to_v weights and drop the trailing singleton
    # conv dimensions to obtain linear-layer weights.
    weight_q, weight_k, weight_v = checkpoint[f"{old_prefix}.qkv.weight"].chunk(3, dim=0)
    bias_q, bias_k, bias_v = checkpoint[f"{old_prefix}.qkv.bias"].chunk(3, dim=0)

    new_checkpoint[f"{new_prefix}.group_norm.weight"] = checkpoint[f"{old_prefix}.norm.weight"]
    new_checkpoint[f"{new_prefix}.group_norm.bias"] = checkpoint[f"{old_prefix}.norm.bias"]

    new_checkpoint[f"{new_prefix}.to_q.weight"] = weight_q.squeeze(-1).squeeze(-1)
    new_checkpoint[f"{new_prefix}.to_q.bias"] = bias_q.squeeze(-1).squeeze(-1)
    new_checkpoint[f"{new_prefix}.to_k.weight"] = weight_k.squeeze(-1).squeeze(-1)
    new_checkpoint[f"{new_prefix}.to_k.bias"] = bias_k.squeeze(-1).squeeze(-1)
    new_checkpoint[f"{new_prefix}.to_v.weight"] = weight_v.squeeze(-1).squeeze(-1)
    new_checkpoint[f"{new_prefix}.to_v.bias"] = bias_v.squeeze(-1).squeeze(-1)

    new_checkpoint[f"{new_prefix}.to_out.0.weight"] = (
        checkpoint[f"{old_prefix}.proj_out.weight"].squeeze(-1).squeeze(-1)
    )
    new_checkpoint[f"{new_prefix}.to_out.0.bias"] = checkpoint[f"{old_prefix}.proj_out.bias"].squeeze(-1).squeeze(-1)

    return new_checkpoint

def con_pt_to_diffuser(checkpoint_path: str, unet_config):
    # Translate an original consistency models UNet state dict into UNet2DModel keys.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    new_checkpoint = {}

    new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["time_embed.2.bias"]

    if unet_config["num_class_embeds"] is not None:
        new_checkpoint["class_embedding.weight"] = checkpoint["label_emb.weight"]

    new_checkpoint["conv_in.weight"] = checkpoint["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = checkpoint["input_blocks.0.0.bias"]

    down_block_types = unet_config["down_block_types"]
    layers_per_block = unet_config["layers_per_block"]
    attention_head_dim = unet_config["attention_head_dim"]
    channels_list = unet_config["block_out_channels"]
    # input_blocks.0 is conv_in, so the flattened block index starts at 1.
    current_layer = 1
    prev_channels = channels_list[0]

    for i, layer_type in enumerate(down_block_types):
        current_channels = channels_list[i]
        downsample_block_has_skip = current_channels != prev_channels
        if layer_type == "ResnetDownsampleBlock2D":
            for j in range(layers_per_block):
                new_prefix = f"down_blocks.{i}.resnets.{j}"
                old_prefix = f"input_blocks.{current_layer}.0"
                has_skip = True if j == 0 and downsample_block_has_skip else False
                new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=has_skip)
                current_layer += 1
        elif layer_type == "AttnDownBlock2D":
            for j in range(layers_per_block):
                new_prefix = f"down_blocks.{i}.resnets.{j}"
                old_prefix = f"input_blocks.{current_layer}.0"
                has_skip = True if j == 0 and downsample_block_has_skip else False
                new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=has_skip)
                new_prefix = f"down_blocks.{i}.attentions.{j}"
                old_prefix = f"input_blocks.{current_layer}.1"
                new_checkpoint = convert_attention(
                    checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim
                )
                current_layer += 1

        if i != len(down_block_types) - 1:
            new_prefix = f"down_blocks.{i}.downsamplers.0"
            old_prefix = f"input_blocks.{current_layer}.0"
            new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
            current_layer += 1

        prev_channels = current_channels

    # hardcoded the mid-block for now
    new_prefix = "mid_block.resnets.0"
    old_prefix = "middle_block.0"
    new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
    new_prefix = "mid_block.attentions.0"
    old_prefix = "middle_block.1"
    new_checkpoint = convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim)
    new_prefix = "mid_block.resnets.1"
    old_prefix = "middle_block.2"
    new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)

    current_layer = 0
    up_block_types = unet_config["up_block_types"]

    for i, layer_type in enumerate(up_block_types):
        if layer_type == "ResnetUpsampleBlock2D":
            for j in range(layers_per_block + 1):
                new_prefix = f"up_blocks.{i}.resnets.{j}"
                old_prefix = f"output_blocks.{current_layer}.0"
                new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=True)
                current_layer += 1
            if i != len(up_block_types) - 1:
                new_prefix = f"up_blocks.{i}.upsamplers.0"
                old_prefix = f"output_blocks.{current_layer-1}.1"
                new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
        elif layer_type == "AttnUpBlock2D":
            for j in range(layers_per_block + 1):
                new_prefix = f"up_blocks.{i}.resnets.{j}"
                old_prefix = f"output_blocks.{current_layer}.0"
                new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=True)
                new_prefix = f"up_blocks.{i}.attentions.{j}"
                old_prefix = f"output_blocks.{current_layer}.1"
                new_checkpoint = convert_attention(
                    checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim
                )
                current_layer += 1
            if i != len(up_block_types) - 1:
                new_prefix = f"up_blocks.{i}.upsamplers.0"
                old_prefix = f"output_blocks.{current_layer-1}.2"
                new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)

    new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"]
    new_checkpoint["conv_norm_out.bias"] = checkpoint["out.0.bias"]
    new_checkpoint["conv_out.weight"] = checkpoint["out.2.weight"]
    new_checkpoint["conv_out.bias"] = checkpoint["out.2.bias"]

    return new_checkpoint

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--unet_path", default=None, type=str, required=True, help="Path to the unet.pt to convert.")
    parser.add_argument(
        "--dump_path", default=None, type=str, required=True, help="Path to output the converted UNet model."
    )
    parser.add_argument("--class_cond", default=True, type=str, help="Whether the model is class-conditional.")

    args = parser.parse_args()

    args.class_cond = str2bool(args.class_cond)

    ckpt_name = os.path.basename(args.unet_path)
    print(f"Checkpoint: {ckpt_name}")

    # Get U-Net config
    if "imagenet64" in ckpt_name:
        unet_config = IMAGENET_64_UNET_CONFIG
    elif "256" in ckpt_name and (("bedroom" in ckpt_name) or ("cat" in ckpt_name)):
        unet_config = LSUN_256_UNET_CONFIG
    elif "test" in ckpt_name:
        unet_config = TEST_UNET_CONFIG
    else:
        raise ValueError(f"Checkpoint type {ckpt_name} is not currently supported.")

    if not args.class_cond:
        unet_config["num_class_embeds"] = None

    converted_unet_ckpt = con_pt_to_diffuser(args.unet_path, unet_config)

    image_unet = UNet2DModel(**unet_config)
    image_unet.load_state_dict(converted_unet_ckpt)

    # Get scheduler config
    if "cd" in ckpt_name or "test" in ckpt_name:
        scheduler_config = CD_SCHEDULER_CONFIG
    elif "ct" in ckpt_name and "imagenet64" in ckpt_name:
        scheduler_config = CT_IMAGENET_64_SCHEDULER_CONFIG
    elif "ct" in ckpt_name and "256" in ckpt_name and (("bedroom" in ckpt_name) or ("cat" in ckpt_name)):
        scheduler_config = CT_LSUN_256_SCHEDULER_CONFIG
    else:
        raise ValueError(f"Checkpoint type {ckpt_name} is not currently supported.")

    cm_scheduler = CMStochasticIterativeScheduler(**scheduler_config)

    consistency_model = ConsistencyModelPipeline(unet=image_unet, scheduler=cm_scheduler)
    consistency_model.save_pretrained(args.dump_path)
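
For reference, a sketch of a typical invocation and reload (the paths and checkpoint filename are hypothetical; the substring checks above key off names like "cd", "ct", "test", "imagenet64", "bedroom", and "cat"):

# Convert an original consistency models UNet checkpoint to a diffusers pipeline;
# "cd" and "imagenet64" in the filename would select IMAGENET_64_UNET_CONFIG and
# CD_SCHEDULER_CONFIG above:
#
#   python convert_consistency_to_diffusers.py \
#       --unet_path cd_imagenet64_l2.pt \
#       --dump_path ./cd_imagenet64_l2 \
#       --class_cond True

from diffusers import ConsistencyModelPipeline

# Reload the converted pipeline from the dump path.
pipe = ConsistencyModelPipeline.from_pretrained("./cd_imagenet64_l2")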