From 0cd1be42d3b40d6804ded59437220d9079580f45 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 26 Dec 2023 05:09:03 +0000 Subject: [PATCH] update --- src/diffusers/loaders/single_file.py | 46 ++++++++++++-- src/diffusers/loaders/single_file_utils.py | 72 +++++++++++++++++++++- 2 files changed, 112 insertions(+), 6 deletions(-) diff --git a/src/diffusers/loaders/single_file.py b/src/diffusers/loaders/single_file.py index 9f5c3d48d1..809040b62e 100644 --- a/src/diffusers/loaders/single_file.py +++ b/src/diffusers/loaders/single_file.py @@ -55,6 +55,15 @@ MODEL_TYPE_FROM_PIPELINE_CLASS = { "StableUnCLIPPipeline": "FrozenOpenCLIPEmbedder", "StableUnCLIPImg2ImgPipeline": "FrozenOpenCLIPEmbedder", } +PIPELINE_COMPONENTS = { + "unet": , + "vae": "AutoencoderKL", + "text_encoder": "CLIPTextModel", + "text_encoder_2": "CLIPTextModel", + "tokenizer": "CLIPTokenizer", + "tokenizer_2": "CLIPTokenizer", + "scheduler": "DiffusionScheduler", +} def extract_pipeline_component_names(pipeline_class): @@ -120,12 +129,41 @@ def infer_model_type(pipeline_class_name): return MODEL_TYPE_FROM_PIPELINE_CLASS.get(pipeline_class_name, None) -def build_component(component_name, original_config, checkpoint, **kwargs): +def build_component(pipeline_class_name, component_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs): if component_name in kwargs: return kwargs.pop(component_name, None) - component_class = getattr(importlib.import_module("diffusers"), component_name) + if component_name == "unet": + unet = create_unet_model(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs) + return unet + if component_name == "controlnet": + controlnet = create_controlnet_model(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs) + return controlnet + + if component_name == "vae": + vae = create_vae_model(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, model_type, image_size, **kwargs) + return vae + + if component_name in ["text_encoder", "text_encoder_2"]: + text_encoder = create_text_encoder_model(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs) + return text_encoder + + if component_name in ["tokenizer", "tokenizer_2"]: + tokenizer = create_tokenizer(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs) + return tokenizer + + if component_name == "scheduler": + scheduler = create_scheduler(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs) + return scheduler + + if component_name == "image_normalizer": + image_normalizer = create_image_normalizer(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs) + return image_normalizer + + if component_name == "image_normalizer": + image_normalizer = create_image_normalizer(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs) + return image_normalizer return @@ -332,7 +370,7 @@ class FromSingleFileMixin: pipeline_components = {} for component in component_names: - pipeline_components[component] = build_component(component, checkpoint, original_config, **kwargs) + pipeline_components[component] = build_component(pipeline_class_name, component, checkpoint, original_config, **kwargs) pipe = download_from_original_stable_diffusion_ckpt( pretrained_model_link_or_path, @@ -359,7 +397,7 @@ class FromSingleFileMixin: local_files_only=local_files_only, ) - pipe = cls(**pipeline_components, **kwargs) + pipe = cls(**pipeline_components) if torch_dtype is not None: pipe.to(dtype=torch_dtype) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 0effa4d826..76949f42af 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -156,7 +156,7 @@ def set_model_type(original_config, model_type=None): else: raise ValueError("Unable to infer model type from config") - logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}") + logger.debug(f"No `model_type` given, `model_type` inferred as: {model_type}") return model_type @@ -897,10 +897,78 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder return text_model -def create_unet_model(original_config, checkpoint, checkpoint_path_or_dict, model_type, image_size, **kwargs): +def convert_controlnet_checkpoint( + checkpoint, + original_config, + checkpoint_path, + image_size, + upcast_attention, + extract_ema, + use_linear_projection=None, + cross_attention_dim=None, +): + ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True) + ctrlnet_config["upcast_attention"] = upcast_attention + + ctrlnet_config.pop("sample_size") + + if use_linear_projection is not None: + ctrlnet_config["use_linear_projection"] = use_linear_projection + + if cross_attention_dim is not None: + ctrlnet_config["cross_attention_dim"] = cross_attention_dim + + ctx = init_empty_weights if is_accelerate_available() else nullcontext + with ctx(): + controlnet = ControlNetModel(**ctrlnet_config) + + # Some controlnet ckpt files are distributed independently from the rest of the + # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/ + if "time_embed.0.weight" in checkpoint: + skip_extract_state_dict = True + else: + skip_extract_state_dict = False + + converted_ctrl_checkpoint = convert_ldm_unet_checkpoint( + checkpoint, + ctrlnet_config, + path=checkpoint_path, + extract_ema=extract_ema, + controlnet=True, + skip_extract_state_dict=skip_extract_state_dict, + ) + + if is_accelerate_available(): + for param_name, param in converted_ctrl_checkpoint.items(): + set_module_tensor_to_device(controlnet, param_name, "cpu", value=param) + else: + controlnet.load_state_dict(converted_ctrl_checkpoint) + + return controlnet + + +def create_unet_model(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, image_size, **kwargs): + if "num_in_channels" in kwargs: + num_in_channels = kwargs.pop("num_in_channels") + elif pipeline_class_name in [ + "StableDiffusionInpaintPipeline", + "StableDiffusionXLInpaintPipeline", + "StableDiffusionXLControlNetInpaintPipeline"]: + num_in_channels = 9 + elif pipeline_class_name == "StableDiffusionUpscalePipeline": + num_in_channels = 7 + else: + num_in_channels = 4 + + if "upcast_attention" in kwargs: + upcast_attention = kwargs.pop("upcast_attention") + extract_ema = kwargs.get("extract_ema", False) unet_config = create_unet_diffusers_config(original_config, image_size=image_size) + unet_config["num_in_channels"] = num_in_channels + unet_config["upcast_attention"] = upcast_attention + path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else "" diffusers_format_unet_checkpoint = convert_ldm_unet_checkpoint( checkpoint, unet_config, path=path, extract_ema=extract_ema