commit 32349c5ba5 (parent eb71c80448)
Author: Dhruv Nair
Date: 2024-01-18 15:08:10 +00:00
2 changed files with 166 additions and 147 deletions

src/diffusers/loaders/single_file.py

@@ -19,6 +19,7 @@ from huggingface_hub.utils import validate_hf_hub_args
from transformers import AutoFeatureExtractor
from ..models.modeling_utils import load_state_dict
from ..pipelines.pipeline_utils import _get_pipeline_class
from ..pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from ..utils import (
is_accelerate_available,
@@ -27,11 +28,11 @@ from ..utils import (
)
from ..utils.hub_utils import _get_model_file
from .single_file_utils import (
create_controlnet_model,
create_scheduler,
create_text_encoders_and_tokenizers,
create_unet_model,
create_vae_model,
create_diffusers_controlnet_model_from_ldm,
create_diffusers_unet_model_from_ldm,
create_diffusers_vae_model_from_ldm,
create_scheduler_from_ldm,
create_text_encoders_and_tokenizers_from_ldm,
fetch_original_config,
infer_model_type,
)
@@ -96,46 +97,57 @@ def _extract_repo_id_and_weights_name(pretrained_model_name_or_path):
return repo_id, weights_name
def build_component(
def build_sub_model_components(
pipeline_components,
pipeline_class_name,
component_name,
original_config,
checkpoint,
checkpoint_path_or_dict,
local_files_only=False,
load_safety_checker=False,
**kwargs,
):
if component_name in kwargs:
component = kwargs.pop(component_name, None)
return {component_name: component}
if component_name in pipeline_components:
return {}
load_safety_checker = kwargs.get("load_safety_checker", False)
local_files_only = kwargs.get("local_files_only", False)
model_type = kwargs.get("model_type", None)
image_size = kwargs.pop("image_size", None)
if component_name == "unet":
unet_components = create_unet_model(
pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs
num_in_channels = kwargs.pop("num_in_channels", None)
unet_components = create_diffusers_unet_model_from_ldm(
pipeline_class_name, original_config, checkpoint, num_in_channels=num_in_channels, image_size=image_size
)
return unet_components
if component_name == "vae":
vae_components = create_vae_model(
pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs
vae_components = create_diffusers_vae_model_from_ldm(
pipeline_class_name, original_config, checkpoint, image_size
)
return vae_components
if component_name == "scheduler":
scheduler_components = create_scheduler(
pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs
scheduler_type = kwargs.get("scheduler_type", "ddim")
prediction_type = kwargs.get("prediction_type", None)
scheduler_components = create_scheduler_from_ldm(
pipeline_class_name,
original_config,
checkpoint,
scheduler_type=scheduler_type,
prediction_type=prediction_type,
model_type=model_type,
)
return scheduler_components
if component_name in ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"]:
text_encoder_components = create_text_encoders_and_tokenizers(
pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs
text_encoder_components = create_text_encoders_and_tokenizers_from_ldm(
original_config,
checkpoint,
model_type=model_type,
local_files_only=local_files_only,
)
return text_encoder_components
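The refactored dispatcher above is internal, but its shape is easy to picture. A hedged sketch, assuming original_config and checkpoint were produced by the fetch_original_config and load_state_dict helpers imported earlier, and mirroring the call made inside from_single_file below; the component list is illustrative:

    resolved = {}
    for name in ["unet", "vae", "scheduler", "text_encoder", "tokenizer"]:
        # Components already resolved (or passed in by the user) are skipped;
        # everything else is built from the LDM checkpoint. The builders
        # return small dicts such as {"unet": ...} that the caller merges.
        components = build_sub_model_components(
            resolved,
            "StableDiffusionPipeline",
            name,
            original_config,
            checkpoint,
            "./v1-5-pruned-emaonly.safetensors",  # checkpoint path, as in the call below
        )
        if components:
            resolved.update(components)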
@@ -156,7 +168,7 @@ def build_component(
return
def build_additional_components(
def set_additional_components(
pipeline_class_name,
original_config,
**kwargs,
@@ -282,36 +294,57 @@ class FromSingleFileMixin:
original_config = fetch_original_config(class_name, checkpoint, original_config_file, config_files)
if class_name == "AutoencoderKL":
component = create_vae_model(class_name, original_config, checkpoint, pretrained_model_link_or_path)
image_size = kwargs.pop("image_size", None)
component = create_diffusers_vae_model_from_ldm(
class_name, original_config, checkpoint, image_size=image_size
)
return component["vae"]
if class_name == "ControlNetModel":
component = create_controlnet_model(class_name, original_config, checkpoint, **kwargs)
upcast_attention = kwargs.pop("upcast_attention", False)
image_size = kwargs.pop("image_size", None)
component = create_diffusers_controlnet_model_from_ldm(
class_name, original_config, checkpoint, upcast_attention=upcast_attention, image_size=image_size
)
return component["controlnet"]
component_names = extract_pipeline_component_names(cls)
pipeline_components = {}
for component in component_names:
components = build_component(
pipeline_components,
class_name,
component,
original_config,
checkpoint,
pretrained_model_link_or_path,
**kwargs,
)
if not components:
continue
pipeline_components.update(components)
pipeline_class = _get_pipeline_class(
cls,
config=None,
cache_dir=cache_dir,
)
additional_components = set(component_names - pipeline_components.keys())
expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
init_kwargs = {}
for name in expected_modules:
if name in passed_class_obj:
init_kwargs[name] = passed_class_obj[name]
else:
components = build_sub_model_components(
init_kwargs,
class_name,
name,
original_config,
checkpoint,
pretrained_model_link_or_path,
**kwargs,
)
if not components:
continue
init_kwargs.update(components)
additional_components = set(optional_kwargs - init_kwargs.keys())
if additional_components:
components = build_additional_components(class_name, original_config, **kwargs)
components = set_additional_components(class_name, original_config, **kwargs)
if components:
pipeline_components.update(components)
init_kwargs.update(components)
pipe = cls(**pipeline_components)
init_kwargs.update(passed_pipe_kwargs)
pipe = pipeline_class(**init_kwargs)
if torch_dtype is not None:
pipe.to(dtype=torch_dtype)
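The public entry points this refactor serves are unchanged. A minimal usage sketch (checkpoint URLs illustrative):

    import torch
    from diffusers import ControlNetModel, StableDiffusionPipeline

    # Full pipeline from one LDM-format file; sub-models are resolved through
    # build_sub_model_components as shown above.
    pipe = StableDiffusionPipeline.from_single_file(
        "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors",
        torch_dtype=torch.float16,
    )

    # Standalone models take the dedicated AutoencoderKL / ControlNetModel branches.
    controlnet = ControlNetModel.from_single_file(
        "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
    )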

src/diffusers/loaders/single_file_utils.py

@@ -14,7 +14,6 @@
# limitations under the License.
""" Conversion script for the Stable Diffusion checkpoints."""
import re
from contextlib import nullcontext
from io import BytesIO
@@ -188,30 +187,6 @@ SD_2_TEXT_ENCODER_KEYS_TO_IGNORE = [
"cond_stage_model.model.text_projection",
]
textenc_conversion_lst = [
("positional_embedding", "text_model.embeddings.position_embedding.weight"),
("token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
("ln_final.weight", "text_model.final_layer_norm.weight"),
("ln_final.bias", "text_model.final_layer_norm.bias"),
("text_projection", "text_projection.weight"),
]
textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst}
textenc_transformer_conversion_lst = [
# (stable-diffusion, HF Diffusers)
("resblocks.", "text_model.encoder.layers."),
("ln_1", "layer_norm1"),
("ln_2", "layer_norm2"),
(".c_fc.", ".fc1."),
(".c_proj.", ".fc2."),
(".attn", ".self_attn"),
("ln_final.", "transformer.text_model.final_layer_norm."),
("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
]
protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
textenc_pattern = re.compile("|".join(protected.keys()))
def fetch_original_config_file_from_url(class_name, checkpoint):
if CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024:
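The conversion tables removed above are worth a gloss: they built a single regex over the LDM-side key fragments and rewrote matches to transformers-style names. A self-contained sketch, with the mapping abbreviated from textenc_transformer_conversion_lst:

    import re

    # Abbreviated from the removed textenc_transformer_conversion_lst.
    mapping = {
        "resblocks.": "text_model.encoder.layers.",
        "ln_1": "layer_norm1",
        ".c_fc.": ".fc1.",
        ".attn": ".self_attn",
    }
    protected = {re.escape(k): v for k, v in mapping.items()}
    pattern = re.compile("|".join(protected.keys()))

    key = "resblocks.3.ln_1.weight"
    print(pattern.sub(lambda m: protected[re.escape(m.group(0))], key))
    # text_model.encoder.layers.3.layer_norm1.weight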
@@ -284,7 +259,7 @@ def load_checkpoint(checkpoint_path_or_dict, device=None, from_safetensors=True)
return checkpoint
def infer_model_type(pipeline_class_name, original_config, model_type=None, **kwargs):
def infer_model_type(original_config, model_type=None):
if model_type is not None:
return model_type
@@ -318,10 +293,12 @@ def get_default_scheduler_config():
return SCHEDULER_DEFAULT_CONFIG
def determine_image_size(pipeline_class_name, original_config, checkpoint, **kwargs):
image_size = kwargs.get("image_size", 512)
def set_image_size(pipeline_class_name, original_config, checkpoint, image_size=None, model_type=None):
if image_size:
return image_size
global_step = checkpoint["global_step"] if "global_step" in checkpoint else None
model_type = infer_model_type(pipeline_class_name, original_config, **kwargs)
model_type = infer_model_type(original_config, model_type)
if pipeline_class_name == "StableDiffusionUpscalePipeline":
image_size = original_config["model"]["params"].unet_config.params.image_size
@@ -340,7 +317,9 @@ def determine_image_size(pipeline_class_name, original_config, checkpoint, **kwa
image_size = 512 if global_step == 875000 else 768
return image_size
return image_size
else:
image_size = 512
return image_size
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear
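A hedged sketch of the renamed resolution helper, assuming original_config and checkpoint are already loaded:

    # An explicit image_size wins. Otherwise it is inferred from the model
    # type (and, for v2 checkpoints, from global_step: 512 at step 875000,
    # else 768), finally falling back to 512.
    image_size = set_image_size(
        "StableDiffusionPipeline",
        original_config,
        checkpoint,
        image_size=None,
    )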
@@ -526,41 +505,36 @@ def update_unet_attention_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint,
new_checkpoint[diffusers_key] = checkpoint.pop(ldm_key)
def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, skip_extract_state_dict=False):
def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False):
"""
Takes a state dict and a config, and returns a converted checkpoint.
"""
if skip_extract_state_dict:
unet_state_dict = checkpoint
# extract state_dict for UNet
unet_state_dict = {}
keys = list(checkpoint.keys())
unet_key = LDM_UNET_KEY
# at least 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
logger.warning("Checkpoint has both EMA and non-EMA weights.")
logger.warning(
"In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
" weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
)
for key in keys:
if key.startswith("model.diffusion_model"):
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
else:
# extract state_dict for UNet
unet_state_dict = {}
keys = list(checkpoint.keys())
unet_key = LDM_UNET_KEY
# at least 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.")
if sum(k.startswith("model_ema") for k in keys) > 100:
logger.warning(
"In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
" weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
"In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
" weights (usually better for inference), please make sure to add the `--extract_ema` flag."
)
for key in keys:
if key.startswith("model.diffusion_model"):
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
else:
if sum(k.startswith("model_ema") for k in keys) > 100:
logger.warning(
"In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
" weights (usually better for inference), please make sure to add the `--extract_ema` flag."
)
for key in keys:
if key.startswith(unet_key):
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
for key in keys:
if key.startswith(unet_key):
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
new_checkpoint = {}
ldm_unet_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["layers"]
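The EMA branch above rests on one heuristic; pulled out as a standalone sketch:

    # A checkpoint is treated as carrying EMA weights when more than 100 of
    # its keys start with "model_ema"; only then does extract_ema select
    # those weights instead of the raw ones.
    def has_ema_weights(state_dict):
        return sum(k.startswith("model_ema") for k in state_dict) > 100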
@@ -792,7 +766,10 @@ def convert_controlnet_checkpoint(
return new_checkpoint
def create_controlnet_model(pipeline_class_name, original_config, checkpoint, **kwargs):
def create_diffusers_controlnet_model_from_ldm(
pipeline_class_name, original_config, checkpoint, upcast_attention=False, image_size=None
):
# import here to avoid circular imports
from ..models import ControlNetModel
# NOTE: this while loop isn't great but this controlnet checkpoint has one additional
@@ -800,8 +777,7 @@ def create_controlnet_model(pipeline_class_name, original_config, checkpoint, **
while "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
image_size = determine_image_size(pipeline_class_name, original_config, checkpoint, **kwargs)
upcast_attention = kwargs.get("upcast_attention", False)
image_size = set_image_size(pipeline_class_name, original_config, checkpoint, image_size=image_size)
diffusers_config = create_controlnet_diffusers_config(original_config, image_size=image_size)
diffusers_config["upcast_attention"] = upcast_attention
@@ -953,7 +929,7 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
return new_checkpoint
def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False):
def create_text_encoder_from_ldm_clip_checkpoint(checkpoint, local_files_only=False):
try:
config = CLIPTextConfig.from_pretrained(LDM_CLIP_CONFIG_NAME, local_files_only=local_files_only)
except Exception:
@@ -988,7 +964,7 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False):
return text_model
def convert_open_clip_checkpoint(
def create_text_encoder_from_open_clip_checkpoint(
checkpoint,
config_name,
prefix="cond_stage_model.model.",
@@ -1069,36 +1045,35 @@ def convert_open_clip_checkpoint(
return text_model
def create_unet_model(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs):
if "num_in_channels" in kwargs:
num_in_channels = kwargs.get("num_in_channels")
def create_diffusers_unet_model_from_ldm(
pipeline_class_name,
original_config,
checkpoint,
num_in_channels=None,
upcast_attention=False,
extract_ema=False,
image_size=None,
):
if num_in_channels is None:
if pipeline_class_name in [
"StableDiffusionInpaintPipeline",
"StableDiffusionXLInpaintPipeline",
"StableDiffusionXLControlNetInpaintPipeline",
]:
num_in_channels = 9
elif pipeline_class_name in [
"StableDiffusionInpaintPipeline",
"StableDiffusionXLInpaintPipeline",
"StableDiffusionXLControlNetInpaintPipeline",
]:
num_in_channels = 9
elif pipeline_class_name == "StableDiffusionUpscalePipeline":
num_in_channels = 7
elif pipeline_class_name == "StableDiffusionUpscalePipeline":
num_in_channels = 7
else:
num_in_channels = 4
image_size = determine_image_size(pipeline_class_name, original_config, checkpoint, **kwargs)
upcast_attention = kwargs.get("upcast_attention", False)
extract_ema = kwargs.get("extract_ema", False)
else:
num_in_channels = 4
image_size = set_image_size(pipeline_class_name, original_config, checkpoint, image_size=image_size)
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
unet_config["in_channels"] = num_in_channels
unet_config["upcast_attention"] = upcast_attention
path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else ""
diffusers_format_unet_checkpoint = convert_ldm_unet_checkpoint(
checkpoint, unet_config, path=path, extract_ema=extract_ema
)
diffusers_format_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config, extract_ema=extract_ema)
ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
unet = UNet2DConditionModel(**unet_config)
@@ -1112,10 +1087,16 @@ def create_unet_model(pipeline_class_name, original_config, checkpoint, checkpoi
return {"unet": unet}
def create_vae_model(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs):
def create_diffusers_vae_model_from_ldm(
pipeline_class_name,
original_config,
checkpoint,
image_size=None,
):
# import here to avoid circular imports
from ..models import AutoencoderKL
image_size = determine_image_size(pipeline_class_name, original_config, checkpoint, **kwargs)
image_size = set_image_size(pipeline_class_name, original_config, checkpoint, image_size=image_size)
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
diffusers_format_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
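Driving the two builders end to end might look like this (hedged sketch; argument order for fetch_original_config follows the call shown earlier, and the checkpoint path is illustrative):

    # load_state_dict and fetch_original_config are the helpers imported in
    # single_file.py above.
    checkpoint = load_state_dict("./v1-5-pruned-emaonly.safetensors")
    original_config = fetch_original_config(
        "StableDiffusionPipeline", checkpoint, None, None
    )

    unet = create_diffusers_unet_model_from_ldm(
        "StableDiffusionPipeline", original_config, checkpoint, extract_ema=False
    )["unet"]
    vae = create_diffusers_vae_model_from_ldm(
        "StableDiffusionPipeline", original_config, checkpoint
    )["vae"]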
@@ -1133,18 +1114,20 @@ def create_vae_model(pipeline_class_name, original_config, checkpoint, checkpoin
return {"vae": vae}
def create_text_encoders_and_tokenizers(
pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs
def create_text_encoders_and_tokenizers_from_ldm(
original_config,
checkpoint,
model_type=None,
local_files_only=False,
):
model_type = infer_model_type(pipeline_class_name, original_config)
local_files_only = kwargs.get("local_files_only", False)
model_type = infer_model_type(original_config, model_type=model_type)
if model_type == "FrozenOpenCLIPEmbedder":
config_name = "stabilityai/stable-diffusion-2"
config_kwargs = {"subfolder": "text_encoder"}
try:
text_encoder = convert_open_clip_checkpoint(
text_encoder = create_text_encoder_from_open_clip_checkpoint(
checkpoint, config_name, local_files_only=local_files_only, **config_kwargs
)
tokenizer = CLIPTokenizer.from_pretrained(
@@ -1160,7 +1143,7 @@ def create_text_encoders_and_tokenizers(
elif model_type == "FrozenCLIPEmbedder":
try:
config_name = "openai/clip-vit-large-patch14"
text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
text_encoder = create_text_encoder_from_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only)
except Exception:
@@ -1177,7 +1160,7 @@ def create_text_encoders_and_tokenizers(
try:
tokenizer_2 = CLIPTokenizer.from_pretrained(config_name, pad_token="!", local_files_only=local_files_only)
text_encoder_2 = convert_open_clip_checkpoint(
text_encoder_2 = create_text_encoder_from_open_clip_checkpoint(
checkpoint,
config_name,
prefix=prefix,
@@ -1185,8 +1168,7 @@ def create_text_encoders_and_tokenizers(
local_files_only=local_files_only,
**config_kwargs,
)
except Exception as e:
raise e
except Exception:
raise ValueError(
f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder_2 and tokenizer_2 in the following path: {config_name} with `pad_token` set to '!'."
)
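For an SD 1.x-style checkpoint, the renamed text-encoder factory might be driven like this (hedged; the returned dict keys are assumed to match the component names used by build_sub_model_components):

    components = create_text_encoders_and_tokenizers_from_ldm(
        original_config,
        checkpoint,
        model_type=None,          # None -> infer_model_type reads the config
        local_files_only=False,
    )
    text_encoder, tokenizer = components["text_encoder"], components["tokenizer"]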
@@ -1203,10 +1185,9 @@ def create_text_encoders_and_tokenizers(
try:
config_name = "openai/clip-vit-large-patch14"
tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only)
text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
text_encoder = create_text_encoder_from_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
except Exception as e:
raise e
except Exception:
raise ValueError(
f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder and tokenizer in the following path: 'openai/clip-vit-large-patch14'."
)
@@ -1216,7 +1197,7 @@ def create_text_encoders_and_tokenizers(
config_kwargs = {"projection_dim": 1280}
prefix = "conditioner.embedders.1.model."
tokenizer_2 = CLIPTokenizer.from_pretrained(config_name, pad_token="!", local_files_only=local_files_only)
text_encoder_2 = convert_open_clip_checkpoint(
text_encoder_2 = create_text_encoder_from_open_clip_checkpoint(
checkpoint,
config_name,
prefix=prefix,
@@ -1239,12 +1220,17 @@ def create_text_encoders_and_tokenizers(
return
def create_scheduler(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs):
def create_scheduler_from_ldm(
pipeline_class_name,
original_config,
checkpoint,
prediction_type=None,
scheduler_type="ddim",
model_type=None,
):
scheduler_config = get_default_scheduler_config()
model_type = infer_model_type(pipeline_class_name, original_config)
model_type = infer_model_type(original_config, model_type=model_type)
scheduler_type = kwargs.get("scheduler_type", "ddim")
prediction_type = kwargs.get("prediction_type", None)
global_step = checkpoint["global_step"] if "global_step" in checkpoint else None
num_train_timesteps = getattr(original_config["model"]["params"], "timesteps", None) or 1000
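And the scheduler factory, by the same pattern (hedged; the return key is assumed to mirror the other builders):

    scheduler = create_scheduler_from_ldm(
        "StableDiffusionPipeline",
        original_config,
        checkpoint,
        scheduler_type="ddim",    # default shown in the signature above
    )["scheduler"]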