diff --git a/src/diffusers/loaders/single_file.py b/src/diffusers/loaders/single_file.py index e8e21edb71..f827b8ca45 100644 --- a/src/diffusers/loaders/single_file.py +++ b/src/diffusers/loaders/single_file.py @@ -346,7 +346,7 @@ class FromSingleFileMixin: while "state_dict" in checkpoint: checkpoint = checkpoint["state_dict"] - original_config = fetch_original_config(checkpoint, original_config_file, config_files) + original_config = fetch_original_config(pipeline_name, checkpoint, original_config_file, config_files) component_names = extract_pipeline_component_names(cls) pipeline_components = {} diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index af3b2b7fa8..fe8e7d3099 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -110,7 +110,7 @@ protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst} textenc_pattern = re.compile("|".join(protected.keys())) -def fetch_original_config_file_from_url(checkpoint): +def fetch_original_config_file_from_url(pipeline_class_name, checkpoint): if CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024: config_url = CONFIG_URLS["v2"] @@ -120,44 +120,41 @@ def fetch_original_config_file_from_url(checkpoint): elif CHECKPOINT_KEY_NAMES["xl_refiner"] in checkpoint: config_url = CONFIG_URLS["xl_refiner"] + elif pipeline_class_name == "StableDiffusionUpscalePipeline": + config_url = CONFIG_URLS["upscale"] + else: config_url = CONFIG_URLS["v1"] - # TODO: Add upscale config - original_config_file = BytesIO(requests.get(config_url).content) return original_config_file -def fetch_original_config_file_from_file(checkpoint, config_files: list): - if "v1" in config_files: - return config_files["v1"] - +def fetch_original_config_file_from_file(config_files: list): if "v2" in config_files: return config_files["v2"] - if "xl" in config_files: + elif "xl" in config_files: return config_files["xl"] - if "xl_refiner" in config_files: + elif "xl_refiner" in config_files: return config_files["xl_refiner"] - # TODO: Add upscale config - - return + else: + return config_files["v1"] -def fetch_original_config(checkpoint, original_config_file=None, config_files=None): +def fetch_original_config(pipeline_class_name, checkpoint, original_config_file=None, config_files=None): if original_config_file: original_config = OmegaConf.load(original_config_file) return original_config elif config_files: - original_config_file = fetch_original_config_file_from_file(checkpoint, config_files) + original_config_file = fetch_original_config_file_from_file(config_files) else: - original_config_file = fetch_original_config_file_from_url(checkpoint) + original_config_file = fetch_original_config_file_from_url(pipeline_class_name, checkpoint) original_config = OmegaConf.load(original_config_file) diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index d80f4dc863..5aa23252b8 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -1436,7 +1436,6 @@ def download_from_original_stable_diffusion_ckpt( if pipeline_class == StableDiffusionUpscalePipeline: image_size = original_config.model.params.unet_config.params.image_size - import ipdb; ipdb.set_trace() # Convert the UNet2DConditionModel model. unet_config = create_unet_diffusers_config(original_config, image_size=image_size)