mirror of https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00

commit: update
@@ -40,6 +40,62 @@ if is_accelerate_available():
logger = logging.get_logger(__name__)


DIFFUSER_PIPELINE_CONFIGS = {
    "StableDiffusionPipeline": None,
    "StableDiffusionImg2ImgPipeline": None,
    "StableDiffusionInpaintPipeline": None,
    "StableDiffusionControlNetPipeline": None,
}

VALID_URL_PREFIXES = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]

MODEL_TYPE_FROM_PIPELINE_CLASS = {
    "StableUnCLIPPipeline": "FrozenOpenCLIPEmbedder",
    "StableUnCLIPImg2ImgPipeline": "FrozenOpenCLIPEmbedder",
}

def check_valid_url(pretrained_model_link_or_path):
    # strip the huggingface URL prefix, if present, and report whether one was found
    has_valid_url_prefix = False
    for prefix in VALID_URL_PREFIXES:
        if pretrained_model_link_or_path.startswith(prefix):
            pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
            has_valid_url_prefix = True

    # return the stripped path as well, so the caller does not lose the prefix removal
    return has_valid_url_prefix, pretrained_model_link_or_path

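A minimal sketch (not part of the commit) of how the helper behaves, assuming the tuple-returning form above:

```python
has_prefix, stripped = check_valid_url(
    "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.ckpt"
)
# has_prefix == True
# stripped == "runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.ckpt"

has_prefix, stripped = check_valid_url("/local/path/model.ckpt")
# has_prefix == False; the path comes back unchanged
```
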
def fetch_model_checkpoint(
    ckpt_path,
    cache_dir=None,
    resume_download=False,
    force_download=False,
    proxies=None,
    local_files_only=None,
    token=None,
    revision=None,
):
    # get repo_id and (potentially nested) file path of ckpt in repo
    repo_id = "/".join(ckpt_path.parts[:2])
    file_path = "/".join(ckpt_path.parts[2:])

    if file_path.startswith("blob/"):
        file_path = file_path[len("blob/") :]

    if file_path.startswith("main/"):
        file_path = file_path[len("main/") :]

    path = hf_hub_download(
        repo_id,
        filename=file_path,
        cache_dir=cache_dir,
        resume_download=resume_download,
        proxies=proxies,
        local_files_only=local_files_only,
        token=token,
        revision=revision,
        force_download=force_download,
    )

    return path

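For orientation, here is how the repo/file split is expected to work on a prefix-stripped hub path (illustrative values only):

```python
from pathlib import Path

ckpt_path = Path("runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.ckpt")
repo_id = "/".join(ckpt_path.parts[:2])    # "runwayml/stable-diffusion-v1-5"
file_path = "/".join(ckpt_path.parts[2:])  # "blob/main/v1-5-pruned.ckpt"
# after stripping "blob/" and "main/": "v1-5-pruned.ckpt"
```
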
def infer_model_type(pipeline_class_name):
    return MODEL_TYPE_FROM_PIPELINE_CLASS.get(pipeline_class_name, None)

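The lookup falls back to `None` for pipelines without a pinned embedder; for example (hypothetical calls against the table defined above):

```python
infer_model_type("StableUnCLIPPipeline")     # "FrozenOpenCLIPEmbedder"
infer_model_type("StableDiffusionPipeline")  # None -> model type inferred elsewhere
```
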

class FromSingleFileMixin:
    """
    Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`].
@@ -150,12 +206,10 @@ class FromSingleFileMixin:
        >>> pipeline.to("cuda")
        ```
        """

        original_config_file = kwargs.pop("original_config_file", None)
        config_files = kwargs.pop("config_files", None)
        cache_dir = kwargs.pop("cache_dir", None)
        resume_download = kwargs.pop("resume_download", False)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
@@ -221,43 +275,15 @@ class FromSingleFileMixin:
        else:
            raise ValueError(f"Unhandled pipeline class: {pipeline_name}")

-        # remove huggingface url
-        has_valid_url_prefix = False
-        valid_url_prefixes = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]
-        for prefix in valid_url_prefixes:
-            if pretrained_model_link_or_path.startswith(prefix):
-                pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :]
-                has_valid_url_prefix = True
+        has_valid_url_prefix, pretrained_model_link_or_path = check_valid_url(pretrained_model_link_or_path)

        # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
        ckpt_path = Path(pretrained_model_link_or_path)
-        if not ckpt_path.is_file():
-            if not has_valid_url_prefix:
-                raise ValueError(
-                    f"The provided path is either not a file or a valid huggingface URL was not provided. Valid URLs begin with {', '.join(valid_url_prefixes)}"
-                )
-
-            # get repo_id and (potentially nested) file path of ckpt in repo
-            repo_id = "/".join(ckpt_path.parts[:2])
-            file_path = "/".join(ckpt_path.parts[2:])
-
-            if file_path.startswith("blob/"):
-                file_path = file_path[len("blob/") :]
-
-            if file_path.startswith("main/"):
-                file_path = file_path[len("main/") :]
-
-            pretrained_model_link_or_path = hf_hub_download(
-                repo_id,
-                filename=file_path,
-                cache_dir=cache_dir,
-                resume_download=resume_download,
-                proxies=proxies,
-                local_files_only=local_files_only,
-                token=token,
-                revision=revision,
-                force_download=force_download,
-            )
+        if not ckpt_path.is_file():
+            if not has_valid_url_prefix:
+                raise ValueError(
+                    f"The provided path is either not a file or a valid huggingface URL was not provided. Valid URLs begin with {', '.join(VALID_URL_PREFIXES)}"
+                )
+            pretrained_model_link_or_path = fetch_model_checkpoint(
+                ckpt_path,
+                cache_dir=cache_dir,
+                resume_download=resume_download,
+                force_download=force_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                token=token,
+                revision=revision,
+            )

        pipe = download_from_original_stable_diffusion_ckpt(
            pretrained_model_link_or_path,

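Taken together, the refactor leaves the public entry point unchanged; illustrative usage (checkpoint URL in the style of the class docstring, not part of the diff):

```python
from diffusers import StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_single_file(
    "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
)
pipeline.to("cuda")
```
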
@@ -14,7 +14,6 @@
# limitations under the License.
""" Conversion script for the Stable Diffusion checkpoints."""

import re
from contextlib import nullcontext
from io import BytesIO
from typing import Dict, Optional, Union
@@ -26,18 +25,12 @@ from safetensors.torch import load_file as safe_load
from transformers import (
    AutoFeatureExtractor,
    BertTokenizerFast,
    CLIPImageProcessor,
    CLIPTextConfig,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    CLIPVisionConfig,
    CLIPVisionModelWithProjection,
)

from ...models import (
    AutoencoderKL,
    ControlNetModel,
    PriorTransformer,
    UNet2DConditionModel,
)
@@ -54,11 +47,8 @@ from ...schedulers import (
)
from ...utils import is_accelerate_available, is_omegaconf_available, logging
from ...utils.import_utils import BACKENDS_MAPPING
from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
from ..paint_by_example import PaintByExampleImageEncoder
from ..pipeline_utils import DiffusionPipeline
from .safety_checker import StableDiffusionSafetyChecker
from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer


if is_accelerate_available():
@@ -147,7 +137,7 @@ def load_checkpoint(checkpoint_path_or_dict, device=None, from_safetensors=True)
    return checkpoint


-def get_model_type(original_config, model_type=None):
+def set_model_type(original_config, model_type=None):
    if model_type is not None:
        return model_type

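Only the explicit-override branch of the renamed helper is visible in this hunk; the remainder presumably falls back to inferring the type from the config. A hedged illustration:

```python
# hypothetical: an explicitly passed model_type short-circuits config inference
set_model_type(original_config, model_type="FrozenOpenCLIPEmbedder")  # -> "FrozenOpenCLIPEmbedder"
```
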
@@ -710,6 +700,242 @@ def convert_ldm_unet_checkpoint(
    return new_checkpoint


def convert_ldm_vae_checkpoint(checkpoint, config):
    # extract state dict for VAE
    vae_state_dict = {}
    keys = list(checkpoint.keys())
    vae_key = "first_stage_model." if any(k.startswith("first_stage_model.") for k in keys) else ""
    for key in keys:
        if key.startswith(vae_key):
            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)

    new_checkpoint = {}

    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]

    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
    new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
    new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]

    new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
    new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
    new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
    new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]

    # Retrieves the keys for the encoder down blocks only
    num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
    down_blocks = {
        layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
    }

    # Retrieves the keys for the decoder up blocks only
    num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
    up_blocks = {
        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
    }

    for i in range(num_down_blocks):
        resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]

        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.weight"
            )
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.bias"
            )

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
        assign_to_checkpoint(
            paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config
        )

    mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(
            paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config
        )

    mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(
        paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config
    )
    conv_attn_to_linear(new_checkpoint)

    for i in range(num_up_blocks):
        block_id = num_up_blocks - 1 - i
        resnets = [
            key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
        ]

        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.weight"
            ]
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.bias"
            ]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
        assign_to_checkpoint(
            paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config
        )

    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(
            paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config
        )

    mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(
        paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config
    )
    conv_attn_to_linear(new_checkpoint)
    return new_checkpoint

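The conversion is essentially a key-renaming pass from the CompVis layout to the diffusers layout; for instance (illustrative key pair):

```python
# CompVis:   "first_stage_model.encoder.down.0.block.1.conv1.weight"
# diffusers: "encoder.down_blocks.0.resnets.1.conv1.weight"
```
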
def convert_ldm_bert_checkpoint(checkpoint, config):
    def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
        hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
        hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
        hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight

        hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
        hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias

    def _copy_linear(hf_linear, pt_linear):
        hf_linear.weight = pt_linear.weight
        hf_linear.bias = pt_linear.bias

    def _copy_layer(hf_layer, pt_layer):
        # copy layer norms
        _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
        _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])

        # copy attn
        _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])

        # copy MLP
        pt_mlp = pt_layer[1][1]
        _copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
        _copy_linear(hf_layer.fc2, pt_mlp.net[2])

    def _copy_layers(hf_layers, pt_layers):
        for i, hf_layer in enumerate(hf_layers):
            if i != 0:
                i += i  # each HF layer maps to a pair of pt layers, so advance in steps of two
            pt_layer = pt_layers[i : i + 2]
            _copy_layer(hf_layer, pt_layer)

    hf_model = LDMBertModel(config).eval()

    # copy embeds
    hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
    hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight

    # copy layer norm
    _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)

    # copy hidden layers
    _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)

    _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)

    return hf_model

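Since `i += i` doubles the index, HF layer `i` always reads the pt slice `[2 * i : 2 * i + 2]`; a quick standalone sanity check (hypothetical layer counts):

```python
# HF layer 0 -> pt layers [0:2], HF layer 1 -> pt layers [2:4], HF layer 2 -> pt layers [4:6]
for i in range(3):
    j = i if i == 0 else i + i
    assert (j, j + 2) == (2 * i, 2 * i + 2)
```
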
def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None):
    if text_encoder is None:
        config_name = "openai/clip-vit-large-patch14"
        try:
            config = CLIPTextConfig.from_pretrained(config_name, local_files_only=local_files_only)
        except Exception:
            raise ValueError(
                f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: 'openai/clip-vit-large-patch14'."
            )

        ctx = init_empty_weights if is_accelerate_available() else nullcontext
        with ctx():
            text_model = CLIPTextModel(config)
    else:
        text_model = text_encoder

    keys = list(checkpoint.keys())

    text_model_dict = {}

    remove_prefixes = ["cond_stage_model.transformer", "conditioner.embedders.0.transformer"]

    for key in keys:
        for prefix in remove_prefixes:
            if key.startswith(prefix):
                text_model_dict[key[len(prefix + ".") :]] = checkpoint[key]

    if is_accelerate_available():
        for param_name, param in text_model_dict.items():
            set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
    else:
        if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings, "position_ids")):
            text_model_dict.pop("text_model.embeddings.position_ids", None)

        text_model.load_state_dict(text_model_dict)

    return text_model

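The prefix stripping turns CompVis keys into names the `CLIPTextModel` state dict expects; for example (illustrative key):

```python
key = "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight"
prefix = "cond_stage_model.transformer"
key[len(prefix + ".") :]  # "text_model.encoder.layers.0.self_attn.q_proj.weight"
```
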
def create_unet_model(original_config, checkpoint, checkpoint_path_or_dict, model_type, image_size, **kwargs):
    extract_ema = kwargs.get("extract_ema", False)

    unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
    path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else ""
    diffusers_format_unet_checkpoint = convert_ldm_unet_checkpoint(
        checkpoint, unet_config, path=path, extract_ema=extract_ema
    )
    ctx = init_empty_weights if is_accelerate_available() else nullcontext
    with ctx():
        unet = UNet2DConditionModel(**unet_config)

    if is_accelerate_available():
        for param_name, param in diffusers_format_unet_checkpoint.items():
            set_module_tensor_to_device(unet, param_name, "cpu", value=param)
    else:
        unet.load_state_dict(diffusers_format_unet_checkpoint)

    return unet


def create_vae_model(original_config, checkpoint, checkpoint_path_or_dict, model_type, image_size, **kwargs):
    vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
    diffusers_format_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
    ctx = init_empty_weights if is_accelerate_available() else nullcontext
    with ctx():
        vae = AutoencoderKL(**vae_config)

    if is_accelerate_available():
        for param_name, param in diffusers_format_vae_checkpoint.items():
            set_module_tensor_to_device(vae, param_name, "cpu", value=param)
    else:
        vae.load_state_dict(diffusers_format_vae_checkpoint)

    return vae

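Both helpers follow the same pattern: instantiate the model on the meta device when accelerate is available, then materialize tensors one by one, which avoids allocating random initial weights that are immediately overwritten. A minimal sketch of the idea (standalone toy module, not from the commit):

```python
import torch
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

with init_empty_weights():  # parameters live on the meta device, no memory allocated
    layer = torch.nn.Linear(4, 4)

state_dict = {"weight": torch.zeros(4, 4), "bias": torch.zeros(4)}
for name, value in state_dict.items():
    set_module_tensor_to_device(layer, name, "cpu", value=value)  # materialize on cpu
```
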
def download_from_original_stable_diffusion_ckpt(
    checkpoint_path_or_dict: Union[str, Dict[str, torch.Tensor]],
    original_config_file: str = None,
@@ -737,6 +963,7 @@ def download_from_original_stable_diffusion_ckpt(
    tokenizer=None,
    tokenizer_2=None,
    config_files=None,
    **kwargs,
) -> DiffusionPipeline:
    """
    Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
@@ -837,15 +1064,10 @@ def download_from_original_stable_diffusion_ckpt(
        checkpoint = checkpoint["state_dict"]

    original_config = fetch_original_config(checkpoint, config_files)
-    model_type = get_model_type(original_config, model_type)
+    model_type = set_model_type(original_config, model_type)

-    unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
-    path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else ""
-    diffusers_format_unet_checkpoint = convert_ldm_unet_checkpoint(
-        checkpoint, unet_config, path=path, extract_ema=extract_ema
-    )
-
-    num_channels = get_num_channels()
+    unet = create_unet_model(original_config, checkpoint, checkpoint_path_or_dict, model_type, image_size, **kwargs)
+    vae = create_vae_model(original_config, checkpoint, checkpoint_path_or_dict, model_type, image_size, **kwargs)

    if pipeline_class is None:
        # Check if we have a SDXL or SD model and initialize default pipeline