mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
fix: un-existing tmp config file in linux, avoid unnecessary disk IO (#2591)
This commit is contained in:
@@ -14,9 +14,8 @@
|
||||
# limitations under the License.
|
||||
""" Conversion script for the Stable Diffusion checkpoints."""
|
||||
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
@@ -1046,31 +1045,23 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
if "state_dict" in checkpoint:
|
||||
checkpoint = checkpoint["state_dict"]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
if original_config_file is None:
|
||||
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||
if original_config_file is None:
|
||||
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||
|
||||
original_config_file = os.path.join(tmpdir, "inference.yaml")
|
||||
if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
|
||||
if not os.path.isfile("v2-inference-v.yaml"):
|
||||
# model_type = "v2"
|
||||
r = requests.get(
|
||||
" https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
|
||||
)
|
||||
open(original_config_file, "wb").write(r.content)
|
||||
# model_type = "v1"
|
||||
config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
|
||||
|
||||
if global_step == 110000:
|
||||
# v2.1 needs to upcast attention
|
||||
upcast_attention = True
|
||||
else:
|
||||
if not os.path.isfile("v1-inference.yaml"):
|
||||
# model_type = "v1"
|
||||
r = requests.get(
|
||||
" https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
|
||||
)
|
||||
open(original_config_file, "wb").write(r.content)
|
||||
if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
|
||||
# model_type = "v2"
|
||||
config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
|
||||
|
||||
original_config = OmegaConf.load(original_config_file)
|
||||
if global_step == 110000:
|
||||
# v2.1 needs to upcast attention
|
||||
upcast_attention = True
|
||||
|
||||
original_config_file = BytesIO(requests.get(config_url).content)
|
||||
|
||||
original_config = OmegaConf.load(original_config_file)
|
||||
|
||||
if num_in_channels is not None:
|
||||
original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
|
||||
|
||||
Reference in New Issue
Block a user