mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
update
This commit is contained in:
@@ -346,7 +346,7 @@ class FromSingleFileMixin:
|
||||
while "state_dict" in checkpoint:
|
||||
checkpoint = checkpoint["state_dict"]
|
||||
|
||||
original_config = fetch_original_config(checkpoint, original_config_file, config_files)
|
||||
original_config = fetch_original_config(pipeline_name, checkpoint, original_config_file, config_files)
|
||||
component_names = extract_pipeline_component_names(cls)
|
||||
|
||||
pipeline_components = {}
|
||||
|
||||
@@ -110,7 +110,7 @@ protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
|
||||
textenc_pattern = re.compile("|".join(protected.keys()))
|
||||
|
||||
|
||||
def fetch_original_config_file_from_url(checkpoint):
|
||||
def fetch_original_config_file_from_url(pipeline_class_name, checkpoint):
|
||||
if CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024:
|
||||
config_url = CONFIG_URLS["v2"]
|
||||
|
||||
@@ -120,44 +120,41 @@ def fetch_original_config_file_from_url(checkpoint):
|
||||
elif CHECKPOINT_KEY_NAMES["xl_refiner"] in checkpoint:
|
||||
config_url = CONFIG_URLS["xl_refiner"]
|
||||
|
||||
elif pipeline_class_name == "StableDiffusionUpscalePipeline":
|
||||
config_url = CONFIG_URLS["upscale"]
|
||||
|
||||
else:
|
||||
config_url = CONFIG_URLS["v1"]
|
||||
|
||||
# TODO: Add upscale config
|
||||
|
||||
original_config_file = BytesIO(requests.get(config_url).content)
|
||||
|
||||
return original_config_file
|
||||
|
||||
|
||||
def fetch_original_config_file_from_file(checkpoint, config_files: list):
|
||||
if "v1" in config_files:
|
||||
return config_files["v1"]
|
||||
|
||||
def fetch_original_config_file_from_file(config_files: list):
|
||||
if "v2" in config_files:
|
||||
return config_files["v2"]
|
||||
|
||||
if "xl" in config_files:
|
||||
elif "xl" in config_files:
|
||||
return config_files["xl"]
|
||||
|
||||
if "xl_refiner" in config_files:
|
||||
elif "xl_refiner" in config_files:
|
||||
return config_files["xl_refiner"]
|
||||
|
||||
# TODO: Add upscale config
|
||||
|
||||
return
|
||||
else:
|
||||
return config_files["v1"]
|
||||
|
||||
|
||||
def fetch_original_config(checkpoint, original_config_file=None, config_files=None):
|
||||
def fetch_original_config(pipeline_class_name, checkpoint, original_config_file=None, config_files=None):
|
||||
if original_config_file:
|
||||
original_config = OmegaConf.load(original_config_file)
|
||||
return original_config
|
||||
|
||||
elif config_files:
|
||||
original_config_file = fetch_original_config_file_from_file(checkpoint, config_files)
|
||||
original_config_file = fetch_original_config_file_from_file(config_files)
|
||||
|
||||
else:
|
||||
original_config_file = fetch_original_config_file_from_url(checkpoint)
|
||||
original_config_file = fetch_original_config_file_from_url(pipeline_class_name, checkpoint)
|
||||
|
||||
original_config = OmegaConf.load(original_config_file)
|
||||
|
||||
|
||||
@@ -1436,7 +1436,6 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
|
||||
if pipeline_class == StableDiffusionUpscalePipeline:
|
||||
image_size = original_config.model.params.unet_config.params.image_size
|
||||
import ipdb; ipdb.set_trace()
|
||||
|
||||
# Convert the UNet2DConditionModel model.
|
||||
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
|
||||
|
||||
Reference in New Issue
Block a user