From fd2ec363692f246cc331c6e6ab53dccf772e21cf Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Fri, 19 Jan 2024 06:13:20 +0000 Subject: [PATCH] update --- src/diffusers/loaders/single_file.py | 38 +++---------------- .../pipeline_stable_diffusion_xl_img2img.py | 1 - 2 files changed, 6 insertions(+), 33 deletions(-) diff --git a/src/diffusers/loaders/single_file.py b/src/diffusers/loaders/single_file.py index 4bb7f330fa..1896440b6c 100644 --- a/src/diffusers/loaders/single_file.py +++ b/src/diffusers/loaders/single_file.py @@ -55,31 +55,6 @@ REFINER_PIPELINES = [ "StableDiffusionXLControlNetImg2ImgPipeline", ] -LOADABLE_CLASSES = { - "diffusers": { - "ControlNetModel": "create_controlnet_model", - "AutoencoderKL": "create_vae_model", - "UNet2DConditionModel": "create_unet_model", - } -} - - -def extract_pipeline_component_names(pipeline_class): - components = inspect.signature(pipeline_class).parameters.keys() - return components - - -def check_valid_url(pretrained_model_link_or_path): - # check if url prefix is valid - # remove huggingface url prefix from model path - has_valid_url_prefix = False - for prefix in VALID_URL_PREFIXES: - if pretrained_model_link_or_path.startswith(prefix): - pretrained_model_link_or_path = pretrained_model_link_or_path[len(prefix) :] - has_valid_url_prefix = True - - return has_valid_url_prefix, pretrained_model_link_or_path - def _extract_repo_id_and_weights_name(pretrained_model_name_or_path): pattern = r"([^/]+)/([^/]+)/(?:blob/main/)?(.+)" @@ -174,8 +149,9 @@ def set_additional_components( **kwargs, ): components = {} + model_type = kwargs.get("model_type", None) if pipeline_class_name in REFINER_PIPELINES: - model_type = infer_model_type(pipeline_class_name, original_config) + model_type = infer_model_type(original_config, model_type=model_type) is_refiner = model_type == "SDXL-Refiner" components.update( { @@ -206,8 +182,7 @@ class FromSingleFileMixin: `"https://huggingface.co//blob/main/.ckpt"`) on the Hub. - A path to a *file* containing all pipeline weights. torch_dtype (`str` or `torch.dtype`, *optional*): - Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the - dtype is automatically derived from the model's weights. + Override the default `torch.dtype` and load the model with another dtype. force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. @@ -291,6 +266,7 @@ class FromSingleFileMixin: ) checkpoint = load_state_dict(checkpoint_path) + # some checkpoints contain the model state dict under a "state_dict" key while "state_dict" in checkpoint: checkpoint = checkpoint["state_dict"] @@ -340,11 +316,9 @@ class FromSingleFileMixin: continue init_kwargs.update(components) - additional_components = set(optional_kwargs - init_kwargs.keys()) + additional_components = set_additional_components(class_name, original_config, **kwargs) if additional_components: - components = set_additional_components(class_name, original_config, **kwargs) - if components: - init_kwargs.update(components) + init_kwargs.update(additional_components) init_kwargs.update(passed_pipe_kwargs) pipe = pipeline_class(**init_kwargs) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 1c22affba1..4e95a876ce 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -792,7 +792,6 @@ class StableDiffusionXLImg2ImgPipeline( self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features - if ( expected_add_embed_dim > passed_add_embed_dim and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim