diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index 81bbbdeea7..bc3d439b60 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -14,9 +14,8 @@ # limitations under the License. """ Conversion script for the Stable Diffusion checkpoints.""" -import os import re -import tempfile +from io import BytesIO from typing import Optional import requests @@ -1046,31 +1045,23 @@ def load_pipeline_from_original_stable_diffusion_ckpt( if "state_dict" in checkpoint: checkpoint = checkpoint["state_dict"] - with tempfile.TemporaryDirectory() as tmpdir: - if original_config_file is None: - key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" + if original_config_file is None: + key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" - original_config_file = os.path.join(tmpdir, "inference.yaml") - if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024: - if not os.path.isfile("v2-inference-v.yaml"): - # model_type = "v2" - r = requests.get( - " https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml" - ) - open(original_config_file, "wb").write(r.content) + # model_type = "v1" + config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" - if global_step == 110000: - # v2.1 needs to upcast attention - upcast_attention = True - else: - if not os.path.isfile("v1-inference.yaml"): - # model_type = "v1" - r = requests.get( - " https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" - ) - open(original_config_file, "wb").write(r.content) + if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024: + # model_type = "v2" + config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml" - original_config = OmegaConf.load(original_config_file) + if global_step == 110000: + # v2.1 needs to upcast attention + upcast_attention = True + + original_config_file = BytesIO(requests.get(config_url).content) + + original_config = OmegaConf.load(original_config_file) if num_in_channels is not None: original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels