1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

fix: un-existing tmp config file in linux, avoid unnecessary disk IO (#2591)

This commit is contained in:
Víctor Martínez
2023-03-08 20:20:09 +01:00
committed by GitHub
parent cbbad0af69
commit 186689affd

View File

@@ -14,9 +14,8 @@
# limitations under the License.
""" Conversion script for the Stable Diffusion checkpoints."""
import os
import re
import tempfile
from io import BytesIO
from typing import Optional
import requests
@@ -1046,31 +1045,23 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
with tempfile.TemporaryDirectory() as tmpdir:
if original_config_file is None:
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
if original_config_file is None:
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
original_config_file = os.path.join(tmpdir, "inference.yaml")
if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
if not os.path.isfile("v2-inference-v.yaml"):
# model_type = "v2"
r = requests.get(
" https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
)
open(original_config_file, "wb").write(r.content)
# model_type = "v1"
config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
if global_step == 110000:
# v2.1 needs to upcast attention
upcast_attention = True
else:
if not os.path.isfile("v1-inference.yaml"):
# model_type = "v1"
r = requests.get(
" https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
)
open(original_config_file, "wb").write(r.content)
if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
# model_type = "v2"
config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
original_config = OmegaConf.load(original_config_file)
if global_step == 110000:
# v2.1 needs to upcast attention
upcast_attention = True
original_config_file = BytesIO(requests.get(config_url).content)
original_config = OmegaConf.load(original_config_file)
if num_in_channels is not None:
original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels