From 32349c5ba5ebeaf9f765d89943e08a739751db14 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 18 Jan 2024 15:08:10 +0000 Subject: [PATCH] update --- src/diffusers/loaders/single_file.py | 117 +++++++----- src/diffusers/loaders/single_file_utils.py | 196 ++++++++++----------- 2 files changed, 166 insertions(+), 147 deletions(-) diff --git a/src/diffusers/loaders/single_file.py b/src/diffusers/loaders/single_file.py index fd08b5abc2..74f7a13433 100644 --- a/src/diffusers/loaders/single_file.py +++ b/src/diffusers/loaders/single_file.py @@ -19,6 +19,7 @@ from huggingface_hub.utils import validate_hf_hub_args from transformers import AutoFeatureExtractor from ..models.modeling_utils import load_state_dict +from ..pipelines.pipeline_utils import _get_pipeline_class from ..pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from ..utils import ( is_accelerate_available, @@ -27,11 +28,11 @@ from ..utils import ( ) from ..utils.hub_utils import _get_model_file from .single_file_utils import ( - create_controlnet_model, - create_scheduler, - create_text_encoders_and_tokenizers, - create_unet_model, - create_vae_model, + create_diffusers_controlnet_model_from_ldm, + create_diffusers_unet_model_from_ldm, + create_diffusers_vae_model_from_ldm, + create_scheduler_from_ldm, + create_text_encoders_and_tokenizers_from_ldm, fetch_original_config, infer_model_type, ) @@ -96,46 +97,57 @@ def _extract_repo_id_and_weights_name(pretrained_model_name_or_path): return repo_id, weights_name -def build_component( +def build_sub_model_components( pipeline_components, pipeline_class_name, component_name, original_config, checkpoint, checkpoint_path_or_dict, + local_files_only=False, + load_safety_checker=False, **kwargs, ): - if component_name in kwargs: - component = kwargs.pop(component_name, None) - return {component_name: component} - if component_name in pipeline_components: return {} - load_safety_checker = kwargs.get("load_safety_checker", False) - local_files_only = kwargs.get("local_files_only", False) + model_type = kwargs.get("model_type", None) + image_size = kwargs.pop("image_size", None) if component_name == "unet": - unet_components = create_unet_model( - pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs + num_in_channels = kwargs.pop("num_in_channels", None) + unet_components = create_diffusers_unet_model_from_ldm( + pipeline_class_name, original_config, checkpoint, num_in_channels=num_in_channels, image_size=image_size ) return unet_components if component_name == "vae": - vae_components = create_vae_model( - pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs + vae_components = create_diffusers_vae_model_from_ldm( + pipeline_class_name, original_config, checkpoint, image_size ) return vae_components if component_name == "scheduler": - scheduler_components = create_scheduler( - pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs + scheduler_type = kwargs.get("scheduler_type", "ddim") + prediction_type = kwargs.get("prediction_type", None) + + scheduler_components = create_scheduler_from_ldm( + pipeline_class_name, + original_config, + checkpoint, + scheduler_type=scheduler_type, + prediction_type=prediction_type, + model_type=model_type, ) + return scheduler_components if component_name in ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"]: - text_encoder_components = create_text_encoders_and_tokenizers( - pipeline_class_name, original_config, checkpoint, 
checkpoint_path_or_dict, **kwargs + text_encoder_components = create_text_encoders_and_tokenizers_from_ldm( + original_config, + checkpoint, + model_type=model_type, + local_files_only=local_files_only, ) return text_encoder_components @@ -156,7 +168,7 @@ def build_component( return -def build_additional_components( +def set_additional_components( pipeline_class_name, original_config, **kwargs, @@ -282,36 +294,57 @@ class FromSingleFileMixin: original_config = fetch_original_config(class_name, checkpoint, original_config_file, config_files) if class_name == "AutoencoderKL": - component = create_vae_model(class_name, original_config, checkpoint, pretrained_model_link_or_path) + image_size = kwargs.pop("image_size", None) + component = create_diffusers_vae_model_from_ldm( + class_name, original_config, checkpoint, image_size=image_size + ) return component["vae"] if class_name == "ControlNetModel": - component = create_controlnet_model(class_name, original_config, checkpoint, **kwargs) + upcast_attention = kwargs.pop("upcast_attention", False) + image_size = kwargs.pop("image_size", None) + + component = create_diffusers_controlnet_model_from_ldm( + class_name, original_config, checkpoint, upcast_attention=upcast_attention, image_size=image_size + ) return component["controlnet"] - component_names = extract_pipeline_component_names(cls) - pipeline_components = {} - for component in component_names: - components = build_component( - pipeline_components, - class_name, - component, - original_config, - checkpoint, - pretrained_model_link_or_path, - **kwargs, - ) - if not components: - continue - pipeline_components.update(components) + pipeline_class = _get_pipeline_class( + cls, + config=None, + cache_dir=cache_dir, + ) - additional_components = set(component_names - pipeline_components.keys()) + expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class) + passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} + passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs} + + init_kwargs = {} + for name in expected_modules: + if name in passed_class_obj: + init_kwargs[name] = passed_class_obj[name] + else: + components = build_sub_model_components( + init_kwargs, + class_name, + name, + original_config, + checkpoint, + pretrained_model_link_or_path, + **kwargs, + ) + if not components: + continue + init_kwargs.update(components) + + additional_components = set(optional_kwargs - init_kwargs.keys()) if additional_components: - components = build_additional_components(class_name, original_config, **kwargs) + components = set_additional_components(class_name, original_config, **kwargs) if components: - pipeline_components.update(components) + init_kwargs.update(components) - pipe = cls(**pipeline_components) + init_kwargs.update(passed_pipe_kwargs) + pipe = pipeline_class(**init_kwargs) if torch_dtype is not None: pipe.to(dtype=torch_dtype) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 52e196bd24..cbfd8ea0d4 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -14,7 +14,6 @@ # limitations under the License. 
""" Conversion script for the Stable Diffusion checkpoints.""" -import re from contextlib import nullcontext from io import BytesIO @@ -188,30 +187,6 @@ SD_2_TEXT_ENCODER_KEYS_TO_IGNORE = [ "cond_stage_model.model.text_projection", ] -textenc_conversion_lst = [ - ("positional_embedding", "text_model.embeddings.position_embedding.weight"), - ("token_embedding.weight", "text_model.embeddings.token_embedding.weight"), - ("ln_final.weight", "text_model.final_layer_norm.weight"), - ("ln_final.bias", "text_model.final_layer_norm.bias"), - ("text_projection", "text_projection.weight"), -] -textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst} - -textenc_transformer_conversion_lst = [ - # (stable-diffusion, HF Diffusers) - ("resblocks.", "text_model.encoder.layers."), - ("ln_1", "layer_norm1"), - ("ln_2", "layer_norm2"), - (".c_fc.", ".fc1."), - (".c_proj.", ".fc2."), - (".attn", ".self_attn"), - ("ln_final.", "transformer.text_model.final_layer_norm."), - ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"), - ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"), -] -protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst} -textenc_pattern = re.compile("|".join(protected.keys())) - def fetch_original_config_file_from_url(class_name, checkpoint): if CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024: @@ -284,7 +259,7 @@ def load_checkpoint(checkpoint_path_or_dict, device=None, from_safetensors=True) return checkpoint -def infer_model_type(pipeline_class_name, original_config, model_type=None, **kwargs): +def infer_model_type(original_config, model_type=None): if model_type is not None: return model_type @@ -318,10 +293,12 @@ def get_default_scheduler_config(): return SCHEDULER_DEFAULT_CONFIG -def determine_image_size(pipeline_class_name, original_config, checkpoint, **kwargs): - image_size = kwargs.get("image_size", 512) +def set_image_size(pipeline_class_name, original_config, checkpoint, image_size=None, model_type=None): + if image_size: + return image_size + global_step = checkpoint["global_step"] if "global_step" in checkpoint else None - model_type = infer_model_type(pipeline_class_name, original_config, **kwargs) + model_type = infer_model_type(original_config, model_type) if pipeline_class_name == "StableDiffusionUpscalePipeline": image_size = original_config["model"]["params"].unet_config.params.image_size @@ -340,7 +317,9 @@ def determine_image_size(pipeline_class_name, original_config, checkpoint, **kwa image_size = 512 if global_step == 875000 else 768 return image_size - return image_size + else: + image_size = 512 + return image_size # Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear @@ -526,41 +505,36 @@ def update_unet_attention_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint, new_checkpoint[diffusers_key] = checkpoint.pop(ldm_key) -def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, skip_extract_state_dict=False): +def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False): """ Takes a state dict and a config, and returns a converted checkpoint. 
""" - if skip_extract_state_dict: - unet_state_dict = checkpoint + # extract state_dict for UNet + unet_state_dict = {} + keys = list(checkpoint.keys()) + unet_key = LDM_UNET_KEY + + # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA + if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: + logger.warning("Checkpoint has both EMA and non-EMA weights.") + logger.warning( + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." + ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) else: - # extract state_dict for UNet - unet_state_dict = {} - keys = list(checkpoint.keys()) - - unet_key = LDM_UNET_KEY - - # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA - if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: - logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.") + if sum(k.startswith("model_ema") for k in keys) > 100: logger.warning( - "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" - " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." + "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" + " weights (usually better for inference), please make sure to add the `--extract_ema` flag." ) - for key in keys: - if key.startswith("model.diffusion_model"): - flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) - unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) - else: - if sum(k.startswith("model_ema") for k in keys) > 100: - logger.warning( - "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" - " weights (usually better for inference), please make sure to add the `--extract_ema` flag." 
- ) - - for key in keys: - if key.startswith(unet_key): - unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) new_checkpoint = {} ldm_unet_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["layers"] @@ -792,7 +766,10 @@ def convert_controlnet_checkpoint( return new_checkpoint -def create_controlnet_model(pipeline_class_name, original_config, checkpoint, **kwargs): +def create_diffusers_controlnet_model_from_ldm( + pipeline_class_name, original_config, checkpoint, upcast_attention=False, image_size=None +): + # import here to avoid circular imports from ..models import ControlNetModel # NOTE: this while loop isn't great but this controlnet checkpoint has one additional @@ -800,8 +777,7 @@ def create_controlnet_model(pipeline_class_name, original_config, checkpoint, ** while "state_dict" in checkpoint: checkpoint = checkpoint["state_dict"] - image_size = determine_image_size(pipeline_class_name, original_config, checkpoint, **kwargs) - upcast_attention = kwargs.get("upcast_attention", False) + image_size = set_image_size(pipeline_class_name, original_config, checkpoint, image_size=image_size) diffusers_config = create_controlnet_diffusers_config(original_config, image_size=image_size) diffusers_config["upcast_attention"] = upcast_attention @@ -953,7 +929,7 @@ def convert_ldm_vae_checkpoint(checkpoint, config): return new_checkpoint -def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False): +def create_text_encoder_from_ldm_clip_checkpoint(checkpoint, local_files_only=False): try: config = CLIPTextConfig.from_pretrained(LDM_CLIP_CONFIG_NAME, local_files_only=local_files_only) except Exception: @@ -988,7 +964,7 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False): return text_model -def convert_open_clip_checkpoint( +def create_text_encoder_from_open_clip_checkpoint( checkpoint, config_name, prefix="cond_stage_model.model.", @@ -1069,36 +1045,35 @@ def convert_open_clip_checkpoint( return text_model -def create_unet_model(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs): - if "num_in_channels" in kwargs: - num_in_channels = kwargs.get("num_in_channels") +def create_diffusers_unet_model_from_ldm( + pipeline_class_name, + original_config, + checkpoint, + num_in_channels=None, + upcast_attention=False, + extract_ema=False, + image_size=None, +): + if num_in_channels is None: + if pipeline_class_name in [ + "StableDiffusionInpaintPipeline", + "StableDiffusionXLInpaintPipeline", + "StableDiffusionXLControlNetInpaintPipeline", + ]: + num_in_channels = 9 - elif pipeline_class_name in [ - "StableDiffusionInpaintPipeline", - "StableDiffusionXLInpaintPipeline", - "StableDiffusionXLControlNetInpaintPipeline", - ]: - num_in_channels = 9 + elif pipeline_class_name == "StableDiffusionUpscalePipeline": + num_in_channels = 7 - elif pipeline_class_name == "StableDiffusionUpscalePipeline": - num_in_channels = 7 - - else: - num_in_channels = 4 - - image_size = determine_image_size(pipeline_class_name, original_config, checkpoint, **kwargs) - - upcast_attention = kwargs.get("upcast_attention", False) - extract_ema = kwargs.get("extract_ema", False) + else: + num_in_channels = 4 + image_size = set_image_size(pipeline_class_name, original_config, checkpoint, image_size=image_size) unet_config = create_unet_diffusers_config(original_config, image_size=image_size) unet_config["in_channels"] = num_in_channels 
unet_config["upcast_attention"] = upcast_attention - path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else "" - diffusers_format_unet_checkpoint = convert_ldm_unet_checkpoint( - checkpoint, unet_config, path=path, extract_ema=extract_ema - ) + diffusers_format_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config, extract_ema=extract_ema) ctx = init_empty_weights if is_accelerate_available() else nullcontext with ctx(): unet = UNet2DConditionModel(**unet_config) @@ -1112,10 +1087,16 @@ def create_unet_model(pipeline_class_name, original_config, checkpoint, checkpoi return {"unet": unet} -def create_vae_model(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs): +def create_diffusers_vae_model_from_ldm( + pipeline_class_name, + original_config, + checkpoint, + image_size=None, +): + # import here to avoid circular imports from ..models import AutoencoderKL - image_size = determine_image_size(pipeline_class_name, original_config, checkpoint, **kwargs) + image_size = set_image_size(pipeline_class_name, original_config, checkpoint, image_size=image_size) vae_config = create_vae_diffusers_config(original_config, image_size=image_size) diffusers_format_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) @@ -1133,18 +1114,20 @@ def create_vae_model(pipeline_class_name, original_config, checkpoint, checkpoin return {"vae": vae} -def create_text_encoders_and_tokenizers( - pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs +def create_text_encoders_and_tokenizers_from_ldm( + original_config, + checkpoint, + model_type=None, + local_files_only=False, ): - model_type = infer_model_type(pipeline_class_name, original_config) - local_files_only = kwargs.get("local_files_only", False) + model_type = infer_model_type(original_config, model_type=model_type) if model_type == "FrozenOpenCLIPEmbedder": config_name = "stabilityai/stable-diffusion-2" config_kwargs = {"subfolder": "text_encoder"} try: - text_encoder = convert_open_clip_checkpoint( + text_encoder = create_text_encoder_from_open_clip_checkpoint( checkpoint, config_name, local_files_only=local_files_only, **config_kwargs ) tokenizer = CLIPTokenizer.from_pretrained( @@ -1160,7 +1143,7 @@ def create_text_encoders_and_tokenizers( elif model_type == "FrozenCLIPEmbedder": try: config_name = "openai/clip-vit-large-patch14" - text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only) + text_encoder = create_text_encoder_from_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only) tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only) except Exception: @@ -1177,7 +1160,7 @@ def create_text_encoders_and_tokenizers( try: tokenizer_2 = CLIPTokenizer.from_pretrained(config_name, pad_token="!", local_files_only=local_files_only) - text_encoder_2 = convert_open_clip_checkpoint( + text_encoder_2 = create_text_encoder_from_open_clip_checkpoint( checkpoint, config_name, prefix=prefix, @@ -1185,8 +1168,7 @@ def create_text_encoders_and_tokenizers( local_files_only=local_files_only, **config_kwargs, ) - except Exception as e: - raise e + except Exception: raise ValueError( f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder_2 and tokenizer_2 in the following path: {config_name} with `pad_token` set to '!'." 
) @@ -1203,10 +1185,9 @@ def create_text_encoders_and_tokenizers( try: config_name = "openai/clip-vit-large-patch14" tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only) - text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only) + text_encoder = create_text_encoder_from_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only) - except Exception as e: - raise e + except Exception: raise ValueError( f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder and tokenizer in the following path: 'openai/clip-vit-large-patch14'." ) @@ -1216,7 +1197,7 @@ def create_text_encoders_and_tokenizers( config_kwargs = {"projection_dim": 1280} prefix = "conditioner.embedders.1.model." tokenizer_2 = CLIPTokenizer.from_pretrained(config_name, pad_token="!", local_files_only=local_files_only) - text_encoder_2 = convert_open_clip_checkpoint( + text_encoder_2 = create_text_encoder_from_open_clip_checkpoint( checkpoint, config_name, prefix=prefix, @@ -1239,12 +1220,17 @@ def create_text_encoders_and_tokenizers( return -def create_scheduler(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs): +def create_scheduler_from_ldm( + pipeline_class_name, + original_config, + checkpoint, + prediction_type=None, + scheduler_type="ddim", + model_type=None, +): scheduler_config = get_default_scheduler_config() - model_type = infer_model_type(pipeline_class_name, original_config) + model_type = infer_model_type(original_config, model_type=model_type) - scheduler_type = kwargs.get("scheduler_type", "ddim") - prediction_type = kwargs.get("prediction_type", None) global_step = checkpoint["global_step"] if "global_step" in checkpoint else None num_train_timesteps = getattr(original_config["model"]["params"], "timesteps", None) or 1000
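-- 
Usage sketch (illustrative, not part of the diff): assuming the public `from_single_file` entry point keeps its existing calling convention, the kwargs below exercise the paths this patch adds — `scheduler_type` and `image_size` are popped and forwarded by `build_sub_model_components` to the new `create_*_from_ldm` helpers, while a component passed by name lands in `passed_class_obj` and skips conversion for that module. The checkpoint filename and the scheduler repo id are placeholders.

    from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline

    # Loader kwargs are forwarded to the new create_*_from_ldm helpers.
    pipe = StableDiffusionPipeline.from_single_file(
        "./v1-5-pruned-emaonly.safetensors",  # placeholder v1-style checkpoint
        scheduler_type="ddim",  # consumed by create_scheduler_from_ldm
        image_size=512,  # consumed by set_image_size in the unet/vae builders
    )

    # Components passed by name are collected into passed_class_obj and
    # short-circuit build_sub_model_components for that component.
    scheduler = EulerDiscreteScheduler.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="scheduler"  # placeholder repo id
    )
    pipe = StableDiffusionPipeline.from_single_file(
        "./v1-5-pruned-emaonly.safetensors", scheduler=scheduler
    )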