mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
update
This commit is contained in:
@@ -24,7 +24,6 @@ import torch
|
||||
from omegaconf import OmegaConf
|
||||
from safetensors.torch import load_file as safe_load
|
||||
from transformers import (
|
||||
AutoFeatureExtractor,
|
||||
BertTokenizerFast,
|
||||
CLIPImageProcessor,
|
||||
CLIPTextConfig,
|
||||
@@ -35,7 +34,7 @@ from transformers import (
|
||||
CLIPVisionModelWithProjection,
|
||||
)
|
||||
|
||||
from ..models import AutoencoderKL, ControlNetModel, PriorTransformer, UNet2DConditionModel
|
||||
from ..models import AutoencoderKL, PriorTransformer, UNet2DConditionModel
|
||||
from ..pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
|
||||
from ..pipelines.paint_by_example import PaintByExampleImageEncoder
|
||||
from ..pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
|
||||
@@ -240,6 +239,7 @@ def determine_image_size(pipeline_class_name, original_config, checkpoint, **kwa
|
||||
return image_size
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.shave_segments
|
||||
def shave_segments(path, n_shave_prefix_segments=1):
|
||||
"""
|
||||
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
||||
@@ -250,6 +250,7 @@ def shave_segments(path, n_shave_prefix_segments=1):
|
||||
return ".".join(path.split(".")[:n_shave_prefix_segments])
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_resnet_paths
|
||||
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
|
||||
"""
|
||||
Updates paths inside resnets to the new naming scheme (local renaming)
|
||||
@@ -272,6 +273,7 @@ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
|
||||
return mapping
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_vae_resnet_paths
|
||||
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
|
||||
"""
|
||||
Updates paths inside resnets to the new naming scheme (local renaming)
|
||||
@@ -288,6 +290,7 @@ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
|
||||
return mapping
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_attention_paths
|
||||
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
|
||||
"""
|
||||
Updates paths inside attentions to the new naming scheme (local renaming)
|
||||
@@ -309,6 +312,7 @@ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
|
||||
return mapping
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_vae_attention_paths
|
||||
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
||||
"""
|
||||
Updates paths inside attentions to the new naming scheme (local renaming)
|
||||
@@ -339,6 +343,7 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
||||
return mapping
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.assign_to_checkpoint
|
||||
def assign_to_checkpoint(
|
||||
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
|
||||
):
|
||||
@@ -394,6 +399,7 @@ def assign_to_checkpoint(
|
||||
checkpoint[new_path] = old_checkpoint[path["old"]]
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear
|
||||
def conv_attn_to_linear(checkpoint):
|
||||
keys = list(checkpoint.keys())
|
||||
attn_keys = ["query.weight", "key.weight", "value.weight"]
|
||||
@@ -406,6 +412,7 @@ def conv_attn_to_linear(checkpoint):
|
||||
checkpoint[key] = checkpoint[key][:, :, 0]
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_unet_diffusers_config
|
||||
def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
|
||||
"""
|
||||
Creates a config for the diffusers based on the config of the LDM model.
|
||||
@@ -510,6 +517,7 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
|
||||
return config
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_vae_diffusers_config
|
||||
def create_vae_diffusers_config(original_config, image_size: int):
|
||||
"""
|
||||
Creates a config for the diffusers based on the config of the LDM model.
|
||||
@@ -534,6 +542,7 @@ def create_vae_diffusers_config(original_config, image_size: int):
|
||||
return config
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.convert_ldm_unet_checkpoint
|
||||
def convert_ldm_unet_checkpoint(
|
||||
checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False
|
||||
):
|
||||
@@ -782,6 +791,7 @@ def convert_ldm_unet_checkpoint(
|
||||
return new_checkpoint
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.convert_ldm_vae_checkpoint
|
||||
def convert_ldm_vae_checkpoint(checkpoint, config):
|
||||
# extract state dict for VAE
|
||||
vae_state_dict = {}
|
||||
@@ -889,6 +899,7 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
|
||||
return new_checkpoint
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.convert_ldm_bert_checkpoint
|
||||
def convert_ldm_bert_checkpoint(checkpoint, config):
|
||||
def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
|
||||
hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
|
||||
@@ -939,6 +950,7 @@ def convert_ldm_bert_checkpoint(checkpoint, config):
|
||||
return hf_model
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.convert_ldm_clip_checkpoint
|
||||
def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None):
|
||||
if text_encoder is None:
|
||||
config_name = "openai/clip-vit-large-patch14"
|
||||
@@ -978,56 +990,7 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder
|
||||
return text_model
|
||||
|
||||
|
||||
def convert_controlnet_checkpoint(
|
||||
checkpoint,
|
||||
original_config,
|
||||
checkpoint_path,
|
||||
image_size,
|
||||
upcast_attention,
|
||||
extract_ema,
|
||||
use_linear_projection=None,
|
||||
cross_attention_dim=None,
|
||||
):
|
||||
ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
|
||||
ctrlnet_config["upcast_attention"] = upcast_attention
|
||||
|
||||
ctrlnet_config.pop("sample_size")
|
||||
|
||||
if use_linear_projection is not None:
|
||||
ctrlnet_config["use_linear_projection"] = use_linear_projection
|
||||
|
||||
if cross_attention_dim is not None:
|
||||
ctrlnet_config["cross_attention_dim"] = cross_attention_dim
|
||||
|
||||
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
with ctx():
|
||||
controlnet = ControlNetModel(**ctrlnet_config)
|
||||
|
||||
# Some controlnet ckpt files are distributed independently from the rest of the
|
||||
# model components i.e. https://huggingface.co/thibaud/controlnet-sd21/
|
||||
if "time_embed.0.weight" in checkpoint:
|
||||
skip_extract_state_dict = True
|
||||
else:
|
||||
skip_extract_state_dict = False
|
||||
|
||||
converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(
|
||||
checkpoint,
|
||||
ctrlnet_config,
|
||||
path=checkpoint_path,
|
||||
extract_ema=extract_ema,
|
||||
controlnet=True,
|
||||
skip_extract_state_dict=skip_extract_state_dict,
|
||||
)
|
||||
|
||||
if is_accelerate_available():
|
||||
for param_name, param in converted_ctrl_checkpoint.items():
|
||||
set_module_tensor_to_device(controlnet, param_name, "cpu", value=param)
|
||||
else:
|
||||
controlnet.load_state_dict(converted_ctrl_checkpoint)
|
||||
|
||||
return controlnet
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.convert_open_clip_checkpoint
|
||||
def convert_open_clip_checkpoint(
|
||||
checkpoint,
|
||||
config_name,
|
||||
@@ -1312,28 +1275,6 @@ def create_unet_model(pipeline_class_name, original_config, checkpoint, checkpoi
|
||||
return {"unet": unet}
|
||||
|
||||
|
||||
def create_controlnet_model(
|
||||
pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, image_size, **kwargs
|
||||
):
|
||||
if "control_stage_config" not in original_config.model.params:
|
||||
raise ValueError("Config does not have controlnet information")
|
||||
|
||||
path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else ""
|
||||
extract_ema = kwargs.get("extract_ema", False)
|
||||
upcast_attention = kwargs.get("upcast_attention", False)
|
||||
|
||||
controlnet = convert_controlnet_checkpoint(
|
||||
checkpoint,
|
||||
original_config,
|
||||
path,
|
||||
image_size,
|
||||
upcast_attention,
|
||||
extract_ema,
|
||||
)
|
||||
|
||||
return {"controlnet": controlnet}
|
||||
|
||||
|
||||
def create_vae_model(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs):
|
||||
image_size = determine_image_size(pipeline_class_name, original_config, checkpoint, **kwargs)
|
||||
|
||||
@@ -1601,32 +1542,3 @@ def create_stable_unclip_components(
|
||||
}
|
||||
|
||||
return
|
||||
|
||||
|
||||
def create_paint_by_example_components(
|
||||
pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs
|
||||
):
|
||||
local_files_only = kwargs.get("local_files_only", False)
|
||||
image_encoder = convert_paint_by_example_checkpoint(checkpoint)
|
||||
|
||||
try:
|
||||
config_name = "openai/clip-vit-large-patch14"
|
||||
tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only)
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'."
|
||||
)
|
||||
|
||||
try:
|
||||
config_name = "CompVis/stable-diffusion-safety-checker"
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained(config_name, local_files_only=local_files_only)
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
f"With local_files_only set to {local_files_only}, you must first locally save the feature_extractor in the following path: 'CompVis/stable-diffusion-safety-checker'."
|
||||
)
|
||||
|
||||
return {
|
||||
"image_encoder": image_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"feature_extractor": feature_extractor,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user