mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
* initial commit * Improve consistency models sampling implementation. * Add CMStochasticIterativeScheduler, which implements the multi-step sampler (stochastic_iterative_sampler) in the original code, and make further improvements to sampling. * Add Unet blocks for consistency models * Add conversion script for Unet * Fix bug in new unet blocks * Fix attention weight loading * Make design improvements to ConsistencyModelPipeline and CMStochasticIterativeScheduler and add initial version of tests. * make style * Make small random test UNet class conditional and set resnet_time_scale_shift to 'scale_shift' to better match consistency model checkpoints. * Add support for converting a test UNet and non-class-conditional UNets to the consistency models conversion script. * make style * Change num_class_embeds to 1000 to better match the original consistency models implementation. * Add support for distillation in pipeline_consistency_models.py. * Improve consistency model tests: - Get small testing checkpoints from hub - Modify tests to take into account "distillation" parameter of ConsistencyModelPipeline - Add onestep, multistep tests for distillation and distillation + class conditional - Add expected image slices for onestep tests * make style * Improve ConsistencyModelPipeline: - Add initial support for class-conditional generation - Fix initial sigma for onestep generation - Fix some sigma shape issues * make style * Improve ConsistencyModelPipeline: - add latents __call__ argument and prepare_latents method - add check_inputs method - add initial docstrings for ConsistencyModelPipeline.__call__ * make style * Fix bug when randomly generating class labels for class-conditional generation. * Switch CMStochasticIterativeScheduler to configuring a sigma schedule and make related changes to the pipeline and tests. * Remove some unused code and make style. * Fix small bug in CMStochasticIterativeScheduler. 
* Add expected slices for multistep sampling tests and make them pass. * Work on consistency model fast tests: - in pipeline, call self.scheduler.scale_model_input before denoising - get expected slices for Euler and Heun scheduler tests - make Euler test pass - mark Heun test as expected fail because it doesn't support prediction_type "sample" yet - remove DPM and Euler Ancestral tests because they don't support use_karras_sigmas * make style * Refactor conversion script to make it easier to add more model architectures to convert in the future. * Work on ConsistencyModelPipeline tests: - Fix device bug when handling class labels in ConsistencyModelPipeline.__call__ - Add slow tests for onestep and multistep sampling and make them pass - Refactor fast tests - Refactor ConsistencyModelPipeline.__init__ * make style * Remove the add_noise and add_noise_to_input methods from CMStochasticIterativeScheduler for now. * Run python utils/check_copies.py --fix_and_overwrite python utils/check_dummies.py --fix_and_overwrite to make dummy objects for new pipeline and scheduler. * Make fast tests from PipelineTesterMixin pass. * make style * Refactor consistency models pipeline and scheduler: - Remove support for Karras schedulers (only support CMStochasticIterativeScheduler) - Move sigma manipulation, input scaling, denoising from pipeline to scheduler - Make corresponding changes to tests and ensure they pass * make style * Add docstrings and further refactor pipeline and scheduler. * make style * Add initial version of the consistency models documentation. * Refactor custom timesteps logic following DDPMScheduler/IFPipeline and temporarily add torch 2.0 SDPA kernel selection logic for debugging. * make style * Convert current slow tests to use fp16 and flash attention. * make style * Add slow tests for normal attention on cuda device. * make style * Fix attention weights loading * Update consistency model fast tests for new test checkpoints with attention fix. 
* make style * apply suggestions * Add add_noise method to CMStochasticIterativeScheduler (copied from EulerDiscreteScheduler). * Conversion script now outputs pipeline instead of UNet and add support for LSUN-256 models and different schedulers. * When both timesteps and num_inference_steps are supplied, raise warning instead of error (timesteps take precedence). * make style * Add remaining diffusers model checkpoints for models in the original consistency model release and update usage example. * apply suggestions from review * make style * fix attention naming * Add tests for CMStochasticIterativeScheduler. * make style * Make CMStochasticIterativeScheduler tests pass. * make style * Override test_step_shape in CMStochasticIterativeSchedulerTest instead of modifying it in SchedulerCommonTest. * make style * rename some models * Improve API * rename some models * Remove duplicated block * Add docstring and make torch compile work * More fixes * Fixes * Apply suggestions from code review * Apply suggestions from code review * add more docstring * update consistency conversion script --------- Co-authored-by: ayushmangal <ayushmangal@microsoft.com> Co-authored-by: Ayush Mangal <43698245+ayushtues@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
314 lines
12 KiB
Python
314 lines
12 KiB
Python
import argparse
|
|
import os
|
|
|
|
import torch
|
|
|
|
from diffusers import (
|
|
CMStochasticIterativeScheduler,
|
|
ConsistencyModelPipeline,
|
|
UNet2DModel,
|
|
)
|
|
|
|
|
|
# UNet2DModel constructor kwargs for each supported checkpoint family. When run as a
# script, the family is chosen from substrings of the checkpoint filename.

# Tiny two-level UNet matching the small test checkpoints (32x32, 1000 class embeddings).
TEST_UNET_CONFIG = dict(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    num_class_embeds=1000,
    block_out_channels=[32, 64],
    attention_head_dim=8,
    down_block_types=["ResnetDownsampleBlock2D", "AttnDownBlock2D"],
    up_block_types=["AttnUpBlock2D", "ResnetUpsampleBlock2D"],
    resnet_time_scale_shift="scale_shift",
    upsample_type="resnet",
    downsample_type="resnet",
)

# ImageNet 64x64: four levels at 192 * (1, 2, 3, 4) channels, attention on the three deepest.
IMAGENET_64_UNET_CONFIG = dict(
    sample_size=64,
    in_channels=3,
    out_channels=3,
    layers_per_block=3,
    num_class_embeds=1000,
    block_out_channels=[192 * mult for mult in (1, 2, 3, 4)],
    attention_head_dim=64,
    down_block_types=["ResnetDownsampleBlock2D"] + ["AttnDownBlock2D"] * 3,
    up_block_types=["AttnUpBlock2D"] * 3 + ["ResnetUpsampleBlock2D"],
    resnet_time_scale_shift="scale_shift",
    upsample_type="resnet",
    downsample_type="resnet",
)

# LSUN 256x256 (bedroom / cat): six levels, unconditional, attention on the three deepest.
LSUN_256_UNET_CONFIG = dict(
    sample_size=256,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    num_class_embeds=None,
    block_out_channels=[256 * mult for mult in (1, 1, 2, 2, 4, 4)],
    attention_head_dim=64,
    down_block_types=["ResnetDownsampleBlock2D"] * 3 + ["AttnDownBlock2D"] * 3,
    up_block_types=["AttnUpBlock2D"] * 3 + ["ResnetUpsampleBlock2D"] * 3,
    resnet_time_scale_shift="default",
    upsample_type="resnet",
    downsample_type="resnet",
)
|
|
|
|
# Noise-level range shared by every CMStochasticIterativeScheduler config below.
_CM_SIGMA_RANGE = {"sigma_min": 0.002, "sigma_max": 80.0}

# Checkpoints whose name contains "cd" (presumably consistency distillation) and the
# small test checkpoints.
CD_SCHEDULER_CONFIG = {"num_train_timesteps": 40, **_CM_SIGMA_RANGE}

# "ct" ImageNet 64x64 checkpoints (presumably consistency training).
CT_IMAGENET_64_SCHEDULER_CONFIG = {"num_train_timesteps": 201, **_CM_SIGMA_RANGE}

# "ct" LSUN 256x256 checkpoints (bedroom / cat).
CT_LSUN_256_SCHEDULER_CONFIG = {"num_train_timesteps": 151, **_CM_SIGMA_RANGE}
|
|
|
|
|
|
def str2bool(v):
    """Interpret a command-line value as a boolean.

    A real ``bool`` is returned unchanged, so this is safe to apply to an argparse
    default of ``True``/``False``. Strings are matched case-insensitively against
    yes/true/t/y/1 and no/false/f/n/0.

    Adapted from
    https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse

    Raises:
        argparse.ArgumentTypeError: if a string matches none of the recognised spellings.
    """
    if isinstance(v, bool):
        return v

    truthy_words = ("yes", "true", "t", "y", "1")
    falsy_words = ("no", "false", "f", "n", "0")
    lowered = v.lower()

    if lowered in truthy_words:
        return True
    if lowered in falsy_words:
        return False
    raise argparse.ArgumentTypeError("boolean value expected")
|
|
|
|
|
|
def convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=False):
    """Copy one residual block's parameters from original naming to diffusers naming.

    Every mapped sub-module contributes a ``weight`` and a ``bias``, read from
    ``checkpoint[f"{old_prefix}.<old>.<param>"]`` and written to
    ``new_checkpoint[f"{new_prefix}.<new>.<param>"]``.

    Args:
        checkpoint: Original state dict to read from.
        new_checkpoint: Destination dict; it is mutated in place.
        old_prefix: Key prefix of the block in ``checkpoint``.
        new_prefix: Key prefix of the block in ``new_checkpoint``.
        has_skip: Also copy ``skip_connection`` into ``conv_shortcut``.

    Returns:
        The same ``new_checkpoint`` object that was passed in.

    Raises:
        KeyError: If an expected source key is missing.
    """
    # (diffusers sub-module, original sub-module), in the order they are written.
    renames = [
        ("norm1", "in_layers.0"),
        ("conv1", "in_layers.2"),
        ("time_emb_proj", "emb_layers.1"),
        ("norm2", "out_layers.0"),
        ("conv2", "out_layers.3"),
    ]
    if has_skip:
        renames.append(("conv_shortcut", "skip_connection"))

    for new_name, old_name in renames:
        for param in ("weight", "bias"):
            new_checkpoint[f"{new_prefix}.{new_name}.{param}"] = checkpoint[f"{old_prefix}.{old_name}.{param}"]

    return new_checkpoint
|
|
|
|
|
|
def convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_dim=None):
    """Split an original fused-QKV attention block into diffusers' separate projections.

    The fused ``qkv`` weight and bias are chunked into three equal parts along dim 0,
    taken as q, k, v in that order. Trailing singleton dims are squeezed off, presumably
    left over from a 1x1 convolution in the original model (TODO confirm).

    Args:
        checkpoint: Original state dict to read from.
        new_checkpoint: Destination dict; it is mutated in place.
        old_prefix: Key prefix of the attention block in ``checkpoint``.
        new_prefix: Key prefix of the attention block in ``new_checkpoint``.
        attention_dim: Unused; accepted for call-site compatibility.

    Returns:
        The same ``new_checkpoint`` object that was passed in.
    """

    def _drop_trailing_unit_dims(tensor):
        # Removes up to two trailing size-1 dims; a no-op on dims of any other size.
        return tensor.squeeze(-1).squeeze(-1)

    # Explicit 3-way unpacking: raises if the fused tensor does not split into three chunks.
    w_q, w_k, w_v = checkpoint[f"{old_prefix}.qkv.weight"].chunk(3, dim=0)
    b_q, b_k, b_v = checkpoint[f"{old_prefix}.qkv.bias"].chunk(3, dim=0)

    new_checkpoint[f"{new_prefix}.group_norm.weight"] = checkpoint[f"{old_prefix}.norm.weight"]
    new_checkpoint[f"{new_prefix}.group_norm.bias"] = checkpoint[f"{old_prefix}.norm.bias"]

    for proj_name, proj_weight, proj_bias in (("to_q", w_q, b_q), ("to_k", w_k, b_k), ("to_v", w_v, b_v)):
        new_checkpoint[f"{new_prefix}.{proj_name}.weight"] = _drop_trailing_unit_dims(proj_weight)
        new_checkpoint[f"{new_prefix}.{proj_name}.bias"] = _drop_trailing_unit_dims(proj_bias)

    out_weight = checkpoint[f"{old_prefix}.proj_out.weight"]
    new_checkpoint[f"{new_prefix}.to_out.0.weight"] = _drop_trailing_unit_dims(out_weight)
    out_bias = checkpoint[f"{old_prefix}.proj_out.bias"]
    new_checkpoint[f"{new_prefix}.to_out.0.bias"] = _drop_trailing_unit_dims(out_bias)

    return new_checkpoint
|
|
|
|
|
|
def con_pt_to_diffuser(checkpoint_path: str, unet_config):
    """Load an original consistency-models UNet checkpoint and re-key it for diffusers.

    The original state dict is walked in the order given by ``unet_config``'s down/up
    block types. Each tensor is stored under the name that ``UNet2DModel(**unet_config)``
    expects. ``current_layer`` tracks the position in the original flat
    ``input_blocks.{k}`` / ``output_blocks.{k}`` lists.

    Args:
        checkpoint_path: Path to the original ``.pt`` file; it is loaded onto the CPU.
        unet_config: UNet config dict (e.g. ``TEST_UNET_CONFIG``). Its
            ``num_class_embeds``, ``down_block_types``, ``up_block_types``,
            ``layers_per_block``, ``attention_head_dim`` and ``block_out_channels``
            entries drive the key mapping.

    Returns:
        A new dict mapping diffusers parameter names to the original tensors.

    Raises:
        ValueError: If ``unet_config`` lists a down/up block type this converter cannot
            map. This is checked before the checkpoint is loaded.
        KeyError: If an expected key is missing from the original checkpoint.
    """
    down_block_types = unet_config["down_block_types"]
    up_block_types = unet_config["up_block_types"]

    # Validate up front, before loading a potentially large checkpoint. Otherwise an
    # unknown block type would be skipped silently, leaving `current_layer` un-advanced.
    # Every later key would then be read from the wrong original index.
    supported_down = ("ResnetDownsampleBlock2D", "AttnDownBlock2D")
    supported_up = ("ResnetUpsampleBlock2D", "AttnUpBlock2D")
    for block_type in down_block_types:
        if block_type not in supported_down:
            raise ValueError(f"Unsupported down block type {block_type!r}; expected one of {supported_down}.")
    for block_type in up_block_types:
        if block_type not in supported_up:
            raise ValueError(f"Unsupported up block type {block_type!r}; expected one of {supported_up}.")

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary code.
    # Only convert checkpoints obtained from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    new_checkpoint = {}

    # Timestep-embedding MLP: original time_embed.0 / time_embed.2 -> linear_1 / linear_2.
    new_checkpoint["time_embedding.linear_1.weight"] = checkpoint["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = checkpoint["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = checkpoint["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = checkpoint["time_embed.2.bias"]

    # The class-label embedding is only present for class-conditional configs.
    if unet_config["num_class_embeds"] is not None:
        new_checkpoint["class_embedding.weight"] = checkpoint["label_emb.weight"]

    # input_blocks.0 is the input convolution.
    new_checkpoint["conv_in.weight"] = checkpoint["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = checkpoint["input_blocks.0.0.bias"]

    layers_per_block = unet_config["layers_per_block"]
    attention_head_dim = unet_config["attention_head_dim"]
    channels_list = unet_config["block_out_channels"]
    current_layer = 1  # input_blocks.0 was consumed by conv_in above
    prev_channels = channels_list[0]

    # Down path. Each layer is one `input_blocks.{k}` entry: the resnet sits at sub-index 0
    # and, for attention blocks, the attention at sub-index 1. Every block except the last
    # is followed by a downsampling resnet that occupies its own entry.
    for i, layer_type in enumerate(down_block_types):
        current_channels = channels_list[i]
        # A block's first resnet carries a skip projection when the channel count changes.
        downsample_block_has_skip = current_channels != prev_channels
        if layer_type == "ResnetDownsampleBlock2D":
            for j in range(layers_per_block):
                new_prefix = f"down_blocks.{i}.resnets.{j}"
                old_prefix = f"input_blocks.{current_layer}.0"
                has_skip = j == 0 and downsample_block_has_skip
                new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=has_skip)
                current_layer += 1

        elif layer_type == "AttnDownBlock2D":
            for j in range(layers_per_block):
                new_prefix = f"down_blocks.{i}.resnets.{j}"
                old_prefix = f"input_blocks.{current_layer}.0"
                has_skip = j == 0 and downsample_block_has_skip
                new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=has_skip)
                new_prefix = f"down_blocks.{i}.attentions.{j}"
                old_prefix = f"input_blocks.{current_layer}.1"
                new_checkpoint = convert_attention(
                    checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim
                )
                current_layer += 1

        if i != len(down_block_types) - 1:
            new_prefix = f"down_blocks.{i}.downsamplers.0"
            old_prefix = f"input_blocks.{current_layer}.0"
            new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
            current_layer += 1

        prev_channels = current_channels

    # Mid block is hardcoded as resnet -> attention -> resnet (middle_block.0/.1/.2).
    new_prefix = "mid_block.resnets.0"
    old_prefix = "middle_block.0"
    new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
    new_prefix = "mid_block.attentions.0"
    old_prefix = "middle_block.1"
    new_checkpoint = convert_attention(checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim)
    new_prefix = "mid_block.resnets.1"
    old_prefix = "middle_block.2"
    new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)

    # Up path. Each of the layers_per_block + 1 layers is one `output_blocks.{k}` entry,
    # and every resnet is converted with its skip projection (has_skip=True). For every
    # block except the last, the upsampling resnet lives inside the block's final entry:
    # at sub-index 1 after a plain resnet, or at sub-index 2 after resnet + attention.
    current_layer = 0

    for i, layer_type in enumerate(up_block_types):
        if layer_type == "ResnetUpsampleBlock2D":
            for j in range(layers_per_block + 1):
                new_prefix = f"up_blocks.{i}.resnets.{j}"
                old_prefix = f"output_blocks.{current_layer}.0"
                new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=True)
                current_layer += 1

            if i != len(up_block_types) - 1:
                new_prefix = f"up_blocks.{i}.upsamplers.0"
                old_prefix = f"output_blocks.{current_layer-1}.1"
                new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
        elif layer_type == "AttnUpBlock2D":
            for j in range(layers_per_block + 1):
                new_prefix = f"up_blocks.{i}.resnets.{j}"
                old_prefix = f"output_blocks.{current_layer}.0"
                new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix, has_skip=True)
                new_prefix = f"up_blocks.{i}.attentions.{j}"
                old_prefix = f"output_blocks.{current_layer}.1"
                new_checkpoint = convert_attention(
                    checkpoint, new_checkpoint, old_prefix, new_prefix, attention_head_dim
                )
                current_layer += 1

            if i != len(up_block_types) - 1:
                new_prefix = f"up_blocks.{i}.upsamplers.0"
                old_prefix = f"output_blocks.{current_layer-1}.2"
                new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)

    # Output head: original out.0 / out.2 -> conv_norm_out / conv_out.
    new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"]
    new_checkpoint["conv_norm_out.bias"] = checkpoint["out.0.bias"]
    new_checkpoint["conv_out.weight"] = checkpoint["out.2.weight"]
    new_checkpoint["conv_out.bias"] = checkpoint["out.2.bias"]

    return new_checkpoint
|
|
|
|
|
|
if __name__ == "__main__":
    # Convert an original consistency-models UNet checkpoint into a diffusers
    # ConsistencyModelPipeline. The UNet and scheduler configs are chosen from
    # substrings of the checkpoint's filename.
    parser = argparse.ArgumentParser()

    parser.add_argument("--unet_path", default=None, type=str, required=True, help="Path to the unet.pt to convert.")
    parser.add_argument(
        "--dump_path",
        default=None,
        type=str,
        required=True,
        help="Path to output the converted ConsistencyModelPipeline.",
    )
    # Use str2bool as the argparse type so a bad value is reported as a usage error.
    # argparse does not apply `type` to a non-string default, so the default stays True.
    parser.add_argument("--class_cond", default=True, type=str2bool, help="Whether the model is class-conditional.")

    args = parser.parse_args()

    ckpt_name = os.path.basename(args.unet_path)
    print(f"Checkpoint: {ckpt_name}")

    # Get U-Net config
    if "imagenet64" in ckpt_name:
        base_unet_config = IMAGENET_64_UNET_CONFIG
    elif "256" in ckpt_name and (("bedroom" in ckpt_name) or ("cat" in ckpt_name)):
        base_unet_config = LSUN_256_UNET_CONFIG
    elif "test" in ckpt_name:
        base_unet_config = TEST_UNET_CONFIG
    else:
        raise ValueError(f"Checkpoint type {ckpt_name} is not currently supported.")

    # Copy before overriding anything: the *_UNET_CONFIG dicts are module-level
    # constants and must not be mutated by the script.
    unet_config = dict(base_unet_config)
    if not args.class_cond:
        unet_config["num_class_embeds"] = None

    converted_unet_ckpt = con_pt_to_diffuser(args.unet_path, unet_config)

    image_unet = UNet2DModel(**unet_config)
    image_unet.load_state_dict(converted_unet_ckpt)

    # Get scheduler config
    if "cd" in ckpt_name or "test" in ckpt_name:
        scheduler_config = CD_SCHEDULER_CONFIG
    elif "ct" in ckpt_name and "imagenet64" in ckpt_name:
        scheduler_config = CT_IMAGENET_64_SCHEDULER_CONFIG
    elif "ct" in ckpt_name and "256" in ckpt_name and (("bedroom" in ckpt_name) or ("cat" in ckpt_name)):
        scheduler_config = CT_LSUN_256_SCHEDULER_CONFIG
    else:
        raise ValueError(f"Checkpoint type {ckpt_name} is not currently supported.")

    cm_scheduler = CMStochasticIterativeScheduler(**scheduler_config)

    consistency_model = ConsistencyModelPipeline(unet=image_unet, scheduler=cm_scheduler)
    consistency_model.save_pretrained(args.dump_path)
|