Improve single loading file (#4041)
* start improving single file load
* Fix more
* start improving single file load
* Fix sd 2.1
* further improve from_single_file
Commit 8bff782354 (parent 6632823690), committed by GitHub.
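For orientation, here is a minimal usage sketch of the from_single_file entry point this commit improves. The checkpoint URL and the scheduler swap are copied from the inpainting test added later in this diff; everything else is illustrative.

import torch
from diffusers import DDIMScheduler, StableDiffusionInpaintPipeline

# Checkpoint URL taken from test_download_ckpt_diff_format_is_same below.
ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-inpainting/blob/main/sd-v1-5-inpainting.ckpt"

pipe = StableDiffusionInpaintPipeline.from_single_file(ckpt_path, torch_dtype=torch.float16)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")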
@@ -1389,7 +1389,7 @@ class FromSingleFileMixin:
         use_auth_token = kwargs.pop("use_auth_token", None)
         revision = kwargs.pop("revision", None)
         extract_ema = kwargs.pop("extract_ema", False)
-        image_size = kwargs.pop("image_size", 512)
+        image_size = kwargs.pop("image_size", None)
         scheduler_type = kwargs.pop("scheduler_type", "pndm")
         num_in_channels = kwargs.pop("num_in_channels", None)
         upcast_attention = kwargs.pop("upcast_attention", None)
@@ -24,6 +24,7 @@ from transformers import (
     AutoFeatureExtractor,
     BertTokenizerFast,
     CLIPImageProcessor,
+    CLIPTextConfig,
     CLIPTextModel,
     CLIPTextModelWithProjection,
     CLIPTokenizer,
@@ -48,7 +49,7 @@ from ...schedulers import (
     PNDMScheduler,
     UnCLIPScheduler,
 )
-from ...utils import is_omegaconf_available, is_safetensors_available, logging
+from ...utils import is_accelerate_available, is_omegaconf_available, is_safetensors_available, logging
 from ...utils.import_utils import BACKENDS_MAPPING
 from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
 from ..paint_by_example import PaintByExampleImageEncoder
@@ -57,6 +58,10 @@ from .safety_checker import StableDiffusionSafetyChecker
 from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer


+if is_accelerate_available():
+    from accelerate import init_empty_weights
+    from accelerate.utils import set_module_tensor_to_device
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

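The two accelerate helpers imported above carry the main idea of this commit: build each model on the meta device so no memory is allocated up front, then place the converted tensors one by one instead of materializing a randomly initialized model and calling load_state_dict. A minimal sketch of that pattern, assuming the CLIP config used by convert_ldm_clip_checkpoint below; the reference state dict merely stands in for a converted checkpoint.

from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from transformers import CLIPTextConfig, CLIPTextModel

config = CLIPTextConfig.from_pretrained("openai/clip-vit-large-patch14")

# Stand-in for a converted checkpoint: any {param_name: tensor} mapping works here.
state_dict = CLIPTextModel(config).state_dict()

with init_empty_weights():
    text_model = CLIPTextModel(config)  # parameters live on the meta device, no memory allocated

for param_name, param in state_dict.items():
    set_module_tensor_to_device(text_model, param_name, "cpu", value=param)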
@@ -770,11 +775,12 @@ def convert_ldm_bert_checkpoint(checkpoint, config):


 def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None):
-    text_model = (
-        CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
-        if text_encoder is None
-        else text_encoder
-    )
+    if text_encoder is None:
+        config_name = "openai/clip-vit-large-patch14"
+        config = CLIPTextConfig.from_pretrained(config_name)
+
+        with init_empty_weights():
+            text_model = CLIPTextModel(config)

     keys = list(checkpoint.keys())

@@ -787,7 +793,8 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder
         if key.startswith(prefix):
             text_model_dict[key[len(prefix + ".") :]] = checkpoint[key]

-    text_model.load_state_dict(text_model_dict)
+    for param_name, param in text_model_dict.items():
+        set_module_tensor_to_device(text_model, param_name, "cpu", value=param)

     return text_model

@@ -884,14 +891,26 @@ def convert_paint_by_example_checkpoint(checkpoint):
     return model


-def convert_open_clip_checkpoint(checkpoint, prefix="cond_stage_model.model."):
+def convert_open_clip_checkpoint(
+    checkpoint, config_name, prefix="cond_stage_model.model.", has_projection=False, **config_kwargs
+):
     # text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
-    text_model = CLIPTextModelWithProjection.from_pretrained(
-        "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280
-    )
+    # text_model = CLIPTextModelWithProjection.from_pretrained(
+    #     "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280
+    # )
+    config = CLIPTextConfig.from_pretrained(config_name, **config_kwargs)
+
+    with init_empty_weights():
+        text_model = CLIPTextModelWithProjection(config) if has_projection else CLIPTextModel(config)

     keys = list(checkpoint.keys())

+    keys_to_ignore = []
+    if config_name == "stabilityai/stable-diffusion-2" and config.num_hidden_layers == 23:
+        # make sure to remove all keys > 22
+        keys_to_ignore += [k for k in keys if k.startswith("cond_stage_model.model.transformer.resblocks.23")]
+        keys_to_ignore += ["cond_stage_model.model.text_projection"]
+
     text_model_dict = {}

     if prefix + "text_projection" in checkpoint:
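With the reworked signature, the caller picks the text-encoder config and whether a projection head is needed. A rough self-contained sketch of the SD 2.x call path, assuming the conversion module is importable as diffusers.pipelines.stable_diffusion.convert_from_ckpt (the checkpoint file mirrors the SD 2.1 test added later in this commit):

from diffusers.pipelines.stable_diffusion.convert_from_ckpt import convert_open_clip_checkpoint
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file as safe_load

# Checkpoint taken from the SD 2.1 test added later in this diff.
filename = hf_hub_download("stabilityai/stable-diffusion-2-1", filename="v2-1_768-ema-pruned.safetensors")
checkpoint = safe_load(filename, device="cpu")

# SD 2.x text encoder: config from the stable-diffusion-2 repo, no projection head.
text_model = convert_open_clip_checkpoint(checkpoint, "stabilityai/stable-diffusion-2", subfolder="text_encoder")

The SDXL branches later in this diff call the same function with the laion/CLIP-ViT-bigG-14-laion2B-39B-b160k config, has_projection=True, and projection_dim=1280.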
@@ -902,8 +921,8 @@ def convert_open_clip_checkpoint(checkpoint, prefix="cond_stage_model.model."):
     text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")

     for key in keys:
-        # if "resblocks.23" in key:  # Diffusers drops the final layer and only uses the penultimate layer
-        #     continue
+        if key in keys_to_ignore:
+            continue
         if key[len(prefix) :] in textenc_conversion_map:
             if key.endswith("text_projection"):
                 value = checkpoint[key].T
@@ -931,7 +950,8 @@ def convert_open_clip_checkpoint(checkpoint, prefix="cond_stage_model.model."):

             text_model_dict[new_key] = checkpoint[key]

-    text_model.load_state_dict(text_model_dict)
+    for param_name, param in text_model_dict.items():
+        set_module_tensor_to_device(text_model, param_name, "cpu", value=param)

     return text_model

@@ -1061,7 +1081,7 @@ def convert_controlnet_checkpoint(
 def download_from_original_stable_diffusion_ckpt(
     checkpoint_path: str,
     original_config_file: str = None,
-    image_size: int = 512,
+    image_size: Optional[int] = None,
     prediction_type: str = None,
     model_type: str = None,
     extract_ema: bool = False,
@@ -1144,6 +1164,7 @@ def download_from_original_stable_diffusion_ckpt(
         LDMTextToImagePipeline,
         PaintByExamplePipeline,
         StableDiffusionControlNetPipeline,
+        StableDiffusionInpaintPipeline,
         StableDiffusionPipeline,
         StableDiffusionXLImg2ImgPipeline,
         StableDiffusionXLPipeline,
@@ -1166,12 +1187,9 @@ def download_from_original_stable_diffusion_ckpt(
         if not is_safetensors_available():
             raise ValueError(BACKENDS_MAPPING["safetensors"][1])

-        from safetensors import safe_open
+        from safetensors.torch import load_file as safe_load

-        checkpoint = {}
-        with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
-            for key in f.keys():
-                checkpoint[key] = f.get_tensor(key)
+        checkpoint = safe_load(checkpoint_path, device="cpu")
     else:
         if device is None:
             device = "cuda" if torch.cuda.is_available() else "cpu"
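The change above swaps the manual safe_open/get_tensor loop for safetensors' one-call loader; both produce the same {name: tensor} dictionary on CPU. A tiny sketch with an illustrative local path:

from safetensors.torch import load_file as safe_load

checkpoint = safe_load("v2-1_768-ema-pruned.safetensors", device="cpu")  # illustrative local path
print(len(checkpoint), "tensors loaded")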
@@ -1183,7 +1201,7 @@ def download_from_original_stable_diffusion_ckpt(
     if "global_step" in checkpoint:
         global_step = checkpoint["global_step"]
     else:
-        logger.warning("global_step key not found in model")
+        logger.debug("global_step key not found in model")
         global_step = None

     # NOTE: this while loop isn't great but this controlnet checkpoint has one additional
@@ -1230,9 +1248,15 @@ def download_from_original_stable_diffusion_ckpt(
             model_type = "SDXL"
         else:
             model_type = "SDXL-Refiner"
+        if image_size is None:
+            image_size = 1024

-    if num_in_channels is not None:
-        original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
+    if num_in_channels is None and pipeline_class == StableDiffusionInpaintPipeline:
+        num_in_channels = 9
+    elif num_in_channels is None:
+        num_in_channels = 4
+
+    original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels

     if (
         "parameterization" in original_config["model"]["params"]
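The new defaulting above makes the UNet input width explicit: inpainting checkpoints get 9 input channels (4 latent + 4 masked-image latent + 1 mask), everything else keeps the usual 4, and a caller-supplied num_in_channels still wins. A small standalone sketch of that rule, not the library code:

def resolve_in_channels(num_in_channels, is_inpainting_pipeline):
    # Mirrors the defaulting introduced above; names here are illustrative.
    if num_in_channels is None and is_inpainting_pipeline:
        return 9  # 4 latent + 4 masked-image latent + 1 mask channel
    if num_in_channels is None:
        return 4  # standard text-to-image latent channels
    return num_in_channels


assert resolve_in_channels(None, True) == 9
assert resolve_in_channels(None, False) == 4
assert resolve_in_channels(8, False) == 8  # explicit value is respected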
@@ -1263,7 +1287,6 @@ def download_from_original_stable_diffusion_ckpt(
     num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000

     if model_type in ["SDXL", "SDXL-Refiner"]:
-        image_size = 1024
         scheduler_dict = {
             "beta_schedule": "scaled_linear",
             "beta_start": 0.00085,
@@ -1279,7 +1302,6 @@ def download_from_original_stable_diffusion_ckpt(
         }
         scheduler = EulerDiscreteScheduler.from_config(scheduler_dict)
         scheduler_type = "euler"
-        vae_path = "stabilityai/sdxl-vae"
     else:
         beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02
         beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085
@@ -1318,25 +1340,45 @@ def download_from_original_stable_diffusion_ckpt(
     # Convert the UNet2DConditionModel model.
     unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
     unet_config["upcast_attention"] = upcast_attention
-    unet = UNet2DConditionModel(**unet_config)
+    with init_empty_weights():
+        unet = UNet2DConditionModel(**unet_config)

     converted_unet_checkpoint = convert_ldm_unet_checkpoint(
         checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
     )
-    unet.load_state_dict(converted_unet_checkpoint)
+
+    for param_name, param in converted_unet_checkpoint.items():
+        set_module_tensor_to_device(unet, param_name, "cpu", value=param)

     # Convert the VAE model.
     if vae_path is None:
         vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
         converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)

-        vae = AutoencoderKL(**vae_config)
-        vae.load_state_dict(converted_vae_checkpoint)
+        if (
+            "model" in original_config
+            and "params" in original_config.model
+            and "scale_factor" in original_config.model.params
+        ):
+            vae_scaling_factor = original_config.model.params.scale_factor
+        else:
+            vae_scaling_factor = 0.18215  # default SD scaling factor
+
+        vae_config["scaling_factor"] = vae_scaling_factor
+
+        with init_empty_weights():
+            vae = AutoencoderKL(**vae_config)
+
+        for param_name, param in converted_vae_checkpoint.items():
+            set_module_tensor_to_device(vae, param_name, "cpu", value=param)
     else:
         vae = AutoencoderKL.from_pretrained(vae_path)

     if model_type == "FrozenOpenCLIPEmbedder":
-        text_model = convert_open_clip_checkpoint(checkpoint)
+        config_name = "stabilityai/stable-diffusion-2"
+        config_kwargs = {"subfolder": "text_encoder"}
+
+        text_model = convert_open_clip_checkpoint(checkpoint, config_name, **config_kwargs)
         tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")

         if stable_unclip is None:
@@ -1469,7 +1511,12 @@ def download_from_original_stable_diffusion_ckpt(
             tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
             text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
             tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!")
-            text_encoder_2 = convert_open_clip_checkpoint(checkpoint, prefix="conditioner.embedders.1.model.")
+
+            config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
+            config_kwargs = {"projection_dim": 1280}
+            text_encoder_2 = convert_open_clip_checkpoint(
+                checkpoint, config_name, prefix="conditioner.embedders.1.model.", has_projection=True, **config_kwargs
+            )

             pipe = StableDiffusionXLPipeline(
                 vae=vae,
@@ -1485,7 +1532,12 @@ def download_from_original_stable_diffusion_ckpt(
             tokenizer = None
             text_encoder = None
             tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!")
-            text_encoder_2 = convert_open_clip_checkpoint(checkpoint, prefix="conditioner.embedders.0.model.")
+
+            config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
+            config_kwargs = {"projection_dim": 1280}
+            text_encoder_2 = convert_open_clip_checkpoint(
+                checkpoint, config_name, prefix="conditioner.embedders.0.model.", has_projection=True, **config_kwargs
+            )

             pipe = StableDiffusionXLImg2ImgPipeline(
                 vae=vae,

@@ -24,7 +24,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

 from ...configuration_utils import FrozenDict
 from ...image_processor import VaeImageProcessor
-from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor
@@ -153,7 +153,9 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
     return mask, masked_image


-class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+class StableDiffusionInpaintPipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
     r"""
     Pipeline for text-guided image inpainting using Stable Diffusion.

@@ -20,17 +20,20 @@ import unittest

 import numpy as np
 import torch
+from huggingface_hub import hf_hub_download
 from PIL import Image
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

 from diffusers import (
     AutoencoderKL,
+    DDIMScheduler,
     DPMSolverMultistepScheduler,
     LMSDiscreteScheduler,
     PNDMScheduler,
     StableDiffusionInpaintPipeline,
     UNet2DConditionModel,
 )
+from diffusers.models.attention_processor import AttnProcessor
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image
 from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
 from diffusers.utils.testing_utils import (
@@ -512,6 +515,42 @@ class StableDiffusionInpaintPipelineSlowTests(unittest.TestCase):

         assert np.abs(expected_slice - image_slice).max() < 6e-4

+    def test_download_local(self):
+        filename = hf_hub_download("runwayml/stable-diffusion-inpainting", filename="sd-v1-5-inpainting.ckpt")
+
+        pipe = StableDiffusionInpaintPipeline.from_single_file(filename, torch_dtype=torch.float16)
+        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+        pipe.to("cuda")
+
+        inputs = self.get_inputs(torch_device)
+        inputs["num_inference_steps"] = 1
+        image_out = pipe(**inputs).images[0]
+
+        assert image_out.shape == (512, 512, 3)
+
+    def test_download_ckpt_diff_format_is_same(self):
+        ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-inpainting/blob/main/sd-v1-5-inpainting.ckpt"
+
+        pipe = StableDiffusionInpaintPipeline.from_single_file(ckpt_path)
+        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+        pipe.unet.set_attn_processor(AttnProcessor())
+        pipe.to("cuda")
+
+        inputs = self.get_inputs(torch_device)
+        inputs["num_inference_steps"] = 5
+        image_ckpt = pipe(**inputs).images[0]
+
+        pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
+        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+        pipe.unet.set_attn_processor(AttnProcessor())
+        pipe.to("cuda")
+
+        inputs = self.get_inputs(torch_device)
+        inputs["num_inference_steps"] = 5
+        image = pipe(**inputs).images[0]
+
+        assert np.max(np.abs(image - image_ckpt)) < 1e-4
+

 @nightly
 @require_torch_gpu
@@ -19,6 +19,7 @@ import unittest

 import numpy as np
 import torch
+from huggingface_hub import hf_hub_download
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

 from diffusers import (
@@ -29,6 +30,7 @@ from diffusers import (
     StableDiffusionPipeline,
     UNet2DConditionModel,
 )
+from diffusers.models.attention_processor import AttnProcessor
 from diffusers.utils import load_numpy, slow, torch_device
 from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu

@@ -426,6 +428,40 @@ class StableDiffusion2VPredictionPipelineIntegrationTests(unittest.TestCase):
         assert image.shape == (768, 768, 3)
         assert np.abs(expected_image - image).max() < 7.5e-1

+    def test_download_local(self):
+        filename = hf_hub_download("stabilityai/stable-diffusion-2-1", filename="v2-1_768-ema-pruned.safetensors")
+
+        pipe = StableDiffusionPipeline.from_single_file(filename, torch_dtype=torch.float16)
+        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+        pipe.to("cuda")
+
+        image_out = pipe("test", num_inference_steps=1, output_type="np").images[0]
+
+        assert image_out.shape == (768, 768, 3)
+
+    def test_download_ckpt_diff_format_is_same(self):
+        single_file_path = (
+            "https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors"
+        )
+
+        pipe_single = StableDiffusionPipeline.from_single_file(single_file_path)
+        pipe_single.scheduler = DDIMScheduler.from_config(pipe_single.scheduler.config)
+        pipe_single.unet.set_attn_processor(AttnProcessor())
+        pipe_single.to("cuda")
+
+        generator = torch.Generator(device="cpu").manual_seed(0)
+        image_ckpt = pipe_single("a turtle", num_inference_steps=5, generator=generator, output_type="np").images[0]
+
+        pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1")
+        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+        pipe.unet.set_attn_processor(AttnProcessor())
+        pipe.to("cuda")
+
+        generator = torch.Generator(device="cpu").manual_seed(0)
+        image = pipe("a turtle", num_inference_steps=5, generator=generator, output_type="np").images[0]
+
+        assert np.max(np.abs(image - image_ckpt)) < 1e-3
+
     def test_stable_diffusion_text2img_intermediate_state_v_pred(self):
         number_of_steps = 0