1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
Dhruv Nair
2023-12-30 08:45:48 +00:00
parent 5daf61a342
commit af6cd361e2
3 changed files with 13 additions and 17 deletions

View File

@@ -346,7 +346,7 @@ class FromSingleFileMixin:
while "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
original_config = fetch_original_config(checkpoint, original_config_file, config_files)
original_config = fetch_original_config(pipeline_name, checkpoint, original_config_file, config_files)
component_names = extract_pipeline_component_names(cls)
pipeline_components = {}

View File

@@ -110,7 +110,7 @@ protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
textenc_pattern = re.compile("|".join(protected.keys()))
def fetch_original_config_file_from_url(checkpoint):
def fetch_original_config_file_from_url(pipeline_class_name, checkpoint):
if CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024:
config_url = CONFIG_URLS["v2"]
@@ -120,44 +120,41 @@ def fetch_original_config_file_from_url(checkpoint):
elif CHECKPOINT_KEY_NAMES["xl_refiner"] in checkpoint:
config_url = CONFIG_URLS["xl_refiner"]
elif pipeline_class_name == "StableDiffusionUpscalePipeline":
config_url = CONFIG_URLS["upscale"]
else:
config_url = CONFIG_URLS["v1"]
# TODO: Add upscale config
original_config_file = BytesIO(requests.get(config_url).content)
return original_config_file
def fetch_original_config_file_from_file(checkpoint, config_files: list):
if "v1" in config_files:
return config_files["v1"]
def fetch_original_config_file_from_file(config_files: list):
if "v2" in config_files:
return config_files["v2"]
if "xl" in config_files:
elif "xl" in config_files:
return config_files["xl"]
if "xl_refiner" in config_files:
elif "xl_refiner" in config_files:
return config_files["xl_refiner"]
# TODO: Add upscale config
return
else:
return config_files["v1"]
def fetch_original_config(checkpoint, original_config_file=None, config_files=None):
def fetch_original_config(pipeline_class_name, checkpoint, original_config_file=None, config_files=None):
if original_config_file:
original_config = OmegaConf.load(original_config_file)
return original_config
elif config_files:
original_config_file = fetch_original_config_file_from_file(checkpoint, config_files)
original_config_file = fetch_original_config_file_from_file(config_files)
else:
original_config_file = fetch_original_config_file_from_url(checkpoint)
original_config_file = fetch_original_config_file_from_url(pipeline_class_name, checkpoint)
original_config = OmegaConf.load(original_config_file)

View File

@@ -1436,7 +1436,6 @@ def download_from_original_stable_diffusion_ckpt(
if pipeline_class == StableDiffusionUpscalePipeline:
image_size = original_config.model.params.unet_config.params.image_size
import ipdb; ipdb.set_trace()
# Convert the UNet2DConditionModel model.
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)