update
@@ -31,7 +31,14 @@ from ..utils import (
    logging,
)
from ..utils.import_utils import BACKENDS_MAPPING
-from .single_file_utils import download_from_original_stable_diffusion_ckpt, fetch_original_config
+from .single_file_utils import (
+    create_scheduler_components,
+    create_stable_unclip_components,
+    create_unet_model,
+    create_vae_model,
+    download_from_original_stable_diffusion_ckpt,
+    fetch_original_config,
+)


if is_transformers_available():
@@ -43,26 +50,13 @@ if is_accelerate_available():
logger = logging.get_logger(__name__)


DIFFUSER_PIPELINE_CONFIGS = {
    "StableDiffusionPipeline": None,
    "StableDiffusionImg2ImgPipeline": None,
    "StableDiffusionInpaintPipeline": None,
    "StableDiffusionControlNetPipeline": None,
}

VALID_URL_PREFIXES = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]
-MODEL_TYPE_FROM_PIPELINE_CLASS = {
+TEXT_ENCODER_FROM_PIPELINE_CLASS = {
    "StableUnCLIPPipeline": "FrozenOpenCLIPEmbedder",
    "StableUnCLIPImg2ImgPipeline": "FrozenOpenCLIPEmbedder",
}
PIPELINE_COMPONENTS = {
    "unet": "UNet2DConditionModel",  # value lost in extraction; assumed to be the diffusers UNet class
    "vae": "AutoencoderKL",
    "text_encoder": "CLIPTextModel",
    "text_encoder_2": "CLIPTextModel",
    "tokenizer": "CLIPTokenizer",
    "tokenizer_2": "CLIPTokenizer",
    "scheduler": "DiffusionScheduler",
    "LDMTextToImagePipeline": "LDMTextToImage",
    "PaintByExamplePipeline": "PaintByExample",
    "StableDiffusion": "stable-diffusion",
}

@@ -82,7 +76,16 @@ def check_valid_url(pretrained_model_link_or_path):
    return has_valid_url_prefix
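(The body of `check_valid_url` is elided by the hunk; presumably it is just a prefix test against `VALID_URL_PREFIXES`, roughly:)

    def check_valid_url(pretrained_model_link_or_path):
        # sketch only: True if the input starts with any known hub prefix
        has_valid_url_prefix = any(
            pretrained_model_link_or_path.startswith(prefix) for prefix in VALID_URL_PREFIXES
        )
        return has_valid_url_prefix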


-def download_model_checkpoint(ckpt_path, cache_dir=None, resume_download=False, force_download=False, proxies=None, local_files_only=None, token=None, revision=None):
+def download_model_checkpoint(
+    ckpt_path,
+    cache_dir=None,
+    resume_download=False,
+    force_download=False,
+    proxies=None,
+    local_files_only=None,
+    token=None,
+    revision=None,
+):
    # get repo_id and (potentially nested) file path of ckpt in repo
    repo_id = "/".join(ckpt_path.parts[:2])
    file_path = "/".join(ckpt_path.parts[2:])
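(For illustration, with an arbitrary repo name, the split above works like this:)

    from pathlib import Path

    ckpt_path = Path("runwayml/stable-diffusion-v1-5/v1-5-pruned.safetensors")
    repo_id = "/".join(ckpt_path.parts[:2])    # "runwayml/stable-diffusion-v1-5"
    file_path = "/".join(ckpt_path.parts[2:])  # "v1-5-pruned.safetensors"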
@@ -125,50 +128,96 @@ def load_checkpoint(checkpoint_path_or_dict, device=None, from_safetensors=True)
    return checkpoint


-def infer_model_type(pipeline_class_name):
-    return MODEL_TYPE_FROM_PIPELINE_CLASS.get(pipeline_class_name, None)


-def build_component(pipeline_class_name, component_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs):
+def build_component(
+    pipeline_components,
+    pipeline_class_name,
+    component_name,
+    original_config,
+    checkpoint,
+    checkpoint_path_or_dict,
+    **kwargs,
+):
    if component_name in kwargs:
        return kwargs.pop(component_name, None)

-    if component_name == "unet":
-        unet = create_unet_model(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs)
-        return unet
+    if component_name in pipeline_components:
+        return {}

-    if component_name == "controlnet":
-        controlnet = create_controlnet_model(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs)
-        return controlnet
+    if component_name == "unet":
+        unet_components = create_unet_model(
+            pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs
+        )
+        return unet_components

-    if component_name == "vae":
-        vae = create_vae_model(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, model_type, image_size, **kwargs)
-        return vae
+    if component_name == "vae":
+        vae_components = create_vae_model(
+            pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs
+        )
+        return vae_components

-    if component_name in ["text_encoder", "text_encoder_2"]:
-        text_encoder = create_text_encoder_model(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs)
-        return text_encoder
+    if component_name == "controlnet":
+        controlnet_components = create_controlnet_model(
+            pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs
+        )
+        return controlnet_components

-    if component_name in ["tokenizer", "tokenizer_2"]:
-        tokenizer = create_tokenizer(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs)
-        return tokenizer
+    if component_name == "adapter":
+        adapter_components = create_adapter_model(
+            pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs
+        )
+        return adapter_components

-    if component_name == "scheduler":
-        scheduler = create_scheduler(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs)
-        return scheduler
-
-    if component_name == "image_normalizer":
-        image_normalizer = create_image_normalizer(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs)
-        return image_normalizer
+    if component_name == "scheduler":
+        scheduler_components = create_scheduler(
+            pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs
+        )
+        return scheduler_components

+    if component_name in ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"]:
+        text_encoder_components = create_text_encoders_and_tokenizers(
+            pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs
+        )
+        return text_encoder_components

    return


def build_additional_components(
    pipeline_components,
    pipeline_class_name,
    component_name,
    original_config,
    checkpoint,
    checkpoint_path_or_dict,
    **kwargs,
):
    if component_name in kwargs:
        return kwargs.pop(component_name, None)

    if component_name in pipeline_components:
        return {}

    local_files_only = kwargs.pop("local_files_only", False)

    if pipeline_class_name in ["StableUnCLIPPipeline", "StableUnCLIPImg2ImgPipeline"]:
        stable_unclip_components = create_stable_unclip_components(
            pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs
        )
        return stable_unclip_components

    if pipeline_class_name == "LDMTextToImagePipeline":
        ldm_text_to_image_components = create_ldm_text_to_image_components(
            pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs
        )
        return ldm_text_to_image_components

    if pipeline_class_name == "PaintByExamplePipeline":
        paint_by_example_components = create_paint_by_example_components(
            pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs
        )
        return paint_by_example_components


class FromSingleFileMixin:
    """
    Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`].
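(For orientation: this mixin backs `from_single_file`, which is typically invoked as in the sketch below; the checkpoint URL is illustrative only.)

    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_single_file(
        "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
    )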
@@ -281,30 +330,22 @@ class FromSingleFileMixin:
        """
        original_config_file = kwargs.pop("original_config_file", None)
        config_files = kwargs.pop("config_files", None)
-        cache_dir = kwargs.pop("cache_dir", None)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
-        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
+        cache_dir = kwargs.pop("cache_dir", None)
+        local_files_only = kwargs.pop("local_files_only", None)
        revision = kwargs.pop("revision", None)
+        torch_dtype = kwargs.pop("torch_dtype", None)
+        use_safetensors = kwargs.pop("use_safetensors", None)
+        load_safety_checker = kwargs.pop("load_safety_checker", True)

        extract_ema = kwargs.pop("extract_ema", False)
        image_size = kwargs.pop("image_size", None)
        scheduler_type = kwargs.pop("scheduler_type", "pndm")
        num_in_channels = kwargs.pop("num_in_channels", None)
        upcast_attention = kwargs.pop("upcast_attention", None)
-        load_safety_checker = kwargs.pop("load_safety_checker", True)
        prediction_type = kwargs.pop("prediction_type", None)
        text_encoder = kwargs.pop("text_encoder", None)
        text_encoder_2 = kwargs.pop("text_encoder_2", None)
        vae = kwargs.pop("vae", None)
        controlnet = kwargs.pop("controlnet", None)
        adapter = kwargs.pop("adapter", None)
        tokenizer = kwargs.pop("tokenizer", None)
        tokenizer_2 = kwargs.pop("tokenizer_2", None)
-
-        torch_dtype = kwargs.pop("torch_dtype", None)
-
-        use_safetensors = kwargs.pop("use_safetensors", None)

        pipeline_name = cls.__name__
        file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1]
@@ -313,42 +354,7 @@ class FromSingleFileMixin:
        if from_safetensors and use_safetensors is False:
            raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.")

-        # TODO: For now we only support stable diffusion
-        stable_unclip = None
-        model_type = None
-
-        if pipeline_name in [
-            "StableDiffusionControlNetPipeline",
-            "StableDiffusionControlNetImg2ImgPipeline",
-            "StableDiffusionControlNetInpaintPipeline",
-        ]:
-            from ..models.controlnet import ControlNetModel
-            from ..pipelines.controlnet.multicontrolnet import MultiControlNetModel
-
-            # list/tuple or a single instance of ControlNetModel or MultiControlNetModel
-            if not (
-                isinstance(controlnet, (ControlNetModel, MultiControlNetModel))
-                or isinstance(controlnet, (list, tuple))
-                and isinstance(controlnet[0], ControlNetModel)
-            ):
-                raise ValueError("ControlNet needs to be passed if loading from ControlNet pipeline.")
-        elif "StableDiffusion" in pipeline_name:
-            # Model type will be inferred from the checkpoint.
-            pass
-        elif pipeline_name == "StableUnCLIPPipeline":
-            model_type = "FrozenOpenCLIPEmbedder"
-            stable_unclip = "txt2img"
-        elif pipeline_name == "StableUnCLIPImg2ImgPipeline":
-            model_type = "FrozenOpenCLIPEmbedder"
-            stable_unclip = "img2img"
-        elif pipeline_name == "PaintByExamplePipeline":
-            model_type = "PaintByExample"
-        elif pipeline_name == "LDMTextToImagePipeline":
-            model_type = "LDMTextToImage"
-        else:
-            raise ValueError(f"Unhandled pipeline class: {pipeline_name}")
-
        has_valid_url_prefix = check_valid_url(pretrained_model_link_or_path)

        # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained
        ckpt_path = Path(pretrained_model_link_or_path)
@@ -356,9 +362,16 @@ class FromSingleFileMixin:
            raise ValueError(
                f"The provided path is either not a file or a valid huggingface URL was not provided. Valid URLs begin with {', '.join(VALID_URL_PREFIXES)}"
            )
-        pretrained_model_link_or_path = download_model_checkpoint(ckpt_path, cache_dir=cache_dir, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, token=token, revision=revision)
+        pretrained_model_link_or_path = download_model_checkpoint(
+            ckpt_path,
+            cache_dir=cache_dir,
+            resume_download=resume_download,
+            proxies=proxies,
+            local_files_only=local_files_only,
+            token=token,
+            revision=revision,
+        )
        checkpoint = load_checkpoint(pretrained_model_link_or_path, from_safetensors=from_safetensors)
-        global_step = checkpoint["global_step"] if "global_step" in checkpoint else None

        # NOTE: this while loop isn't great but this controlnet checkpoint has one additional
        # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21
@@ -370,32 +383,15 @@ class FromSingleFileMixin:

        pipeline_components = {}
        for component in component_names:
-            pipeline_components[component] = build_component(pipeline_name, component, checkpoint, original_config, **kwargs)
+            components = build_component(
+                pipeline_components, pipeline_name, component, checkpoint, original_config, **kwargs
+            )
+            pipeline_components.update(components)

-        pipe = download_from_original_stable_diffusion_ckpt(
-            pretrained_model_link_or_path,
-            pipeline_class=cls,
-            model_type=model_type,
-            stable_unclip=stable_unclip,
-            controlnet=controlnet,
-            adapter=adapter,
-            from_safetensors=from_safetensors,
-            extract_ema=extract_ema,
-            image_size=image_size,
-            scheduler_type=scheduler_type,
-            num_in_channels=num_in_channels,
-            upcast_attention=upcast_attention,
-            load_safety_checker=load_safety_checker,
-            prediction_type=prediction_type,
-            text_encoder=text_encoder,
-            text_encoder_2=text_encoder_2,
-            vae=vae,
-            tokenizer=tokenizer,
-            tokenizer_2=tokenizer_2,
-            original_config_file=original_config_file,
-            config_files=config_files,
-            local_files_only=local_files_only,
-        )
+        additional_components = set(pipeline_components.keys() - component_names)
+        if additional_components:
+            components = build_additional_components(pipeline_name, component, checkpoint, original_config, **kwargs)
+            pipeline_components.update(components)

+        pipe = cls(**pipeline_components)


@@ -26,16 +26,21 @@ from safetensors.torch import load_file as safe_load
from transformers import (
    AutoFeatureExtractor,
    BertTokenizerFast,
    CLIPImageProcessor,
    CLIPTextConfig,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    CLIPVisionConfig,  # assumed: the extracted import list was garbled here; CLIPVisionConfig is used below
    CLIPVisionModel,
    CLIPVisionModelWithProjection,
)

-from ...models import (
-    AutoencoderKL,
-    PriorTransformer,
-    UNet2DConditionModel,
-)
-from ...schedulers import (
+from ..models import AutoencoderKL, ControlNetModel, PriorTransformer, UNet2DConditionModel
+from ..pipelines.pipeline_utils import DiffusionPipeline
+from ..pipelines.stable_unclip.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
+from ..schedulers import (
    DDIMScheduler,
    DDPMScheduler,
    DPMSolverMultistepScheduler,
@@ -46,9 +51,8 @@ from ...schedulers import (
    PNDMScheduler,
    UnCLIPScheduler,
)
-from ...utils import is_accelerate_available, is_omegaconf_available, logging
-from ...utils.import_utils import BACKENDS_MAPPING
-from ..pipeline_utils import DiffusionPipeline
+from ..utils import is_accelerate_available, is_omegaconf_available, logging
+from ..utils.import_utils import BACKENDS_MAPPING
from .safety_checker import StableDiffusionSafetyChecker


@@ -62,7 +66,7 @@ CONFIG_URLS = {
    "v1": "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml",
    "v2": "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml",
    "xl": "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml",
-    "upscale": "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml"
+    "upscale": "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml",
}

CHECKPOINT_KEY_NAMES = {
@@ -71,6 +75,20 @@ CHECKPOINT_KEY_NAMES = {
    "xl_refiner": "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias",
}

+SCHEDULER_DEFAULT_CONFIG = {
+    "beta_schedule": "scaled_linear",
+    "beta_start": 0.00085,
+    "beta_end": 0.012,
+    "interpolation_type": "linear",
+    "num_train_timesteps": 1000,
+    "prediction_type": "epsilon",
+    "sample_max_value": 1.0,
+    "set_alpha_to_one": False,
+    "skip_prk_steps": True,
+    "steps_offset": 1,
+    "timestep_spacing": "leading",
+}
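(Since `SCHEDULER_DEFAULT_CONFIG` is a plain dict, a scheduler can be built from it directly; a minimal sketch, assuming `ConfigMixin.from_config` ignores keys a given scheduler does not declare:)

    from diffusers import DDIMScheduler

    scheduler = DDIMScheduler.from_config(SCHEDULER_DEFAULT_CONFIG)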

textenc_conversion_lst = [
    ("positional_embedding", "text_model.embeddings.position_embedding.weight"),
    ("token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
@@ -109,7 +127,7 @@ def fetch_original_config_file_from_url(checkpoint):
    else:
        config_url = CONFIG_URLS["v1"]

-    #TODO: Add upscale config
+    # TODO: Add upscale config

    original_config_file = BytesIO(requests.get(config_url).content)

@@ -129,7 +147,7 @@ def fetch_original_config_file_from_file(checkpoint, config_files: list):
    if "xl_refiner" in config_files:
        return config_files["xl_refiner"]

-    #TODO: Add upscale config
+    # TODO: Add upscale config

    return

@@ -162,12 +180,21 @@ def load_checkpoint(checkpoint_path_or_dict, device=None, from_safetensors=True)
    return checkpoint


-def set_model_type(original_config, model_type=None):
+def infer_model_type(pipeline_class_name, original_config, model_type=None):
    if model_type is not None:
        return model_type

-    has_cond_stage_config = "cond_stage_config" in original_config.model.params and original_config.model.params.cond_stage_config is not None
-    has_network_config = "network_config" in original_config.model.params and original_config.model.params.network_config is not None
+    if pipeline_class_name in ["StableUnCLIPPipeline", "StableUnCLIPImg2ImgPipeline"]:
+        model_type = "FrozenOpenCLIPEmbedder"
+        return model_type
+
+    has_cond_stage_config = (
+        "cond_stage_config" in original_config.model.params
+        and original_config.model.params.cond_stage_config is not None
+    )
+    has_network_config = (
+        "network_config" in original_config.model.params and original_config.model.params.network_config is not None
+    )

    if has_cond_stage_config:
        model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
@@ -185,6 +212,11 @@ def set_model_type(original_config, model_type=None):

    return model_type


+def get_default_scheduler_config():
+    return SCHEDULER_DEFAULT_CONFIG


def shave_segments(path, n_shave_prefix_segments=1):
    """
    Removes segments. Positive values shave the first segments, negative shave the last segments.
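(The body is elided by the hunk; in diffusers, `shave_segments` amounts to splitting the dotted path and dropping segments from the front, or from the back for negative values:)

    def shave_segments(path, n_shave_prefix_segments=1):
        if n_shave_prefix_segments >= 0:
            return ".".join(path.split(".")[n_shave_prefix_segments:])
        else:
            return ".".join(path.split(".")[:n_shave_prefix_segments])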
@@ -350,6 +382,7 @@ def conv_attn_to_linear(checkpoint):
        if checkpoint[key].ndim > 2:
            checkpoint[key] = checkpoint[key][:, :, 0]


def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
    """
    Creates a config for the diffusers based on the config of the LDM model.
@@ -971,6 +1004,7 @@ def convert_controlnet_checkpoint(

    return controlnet


def convert_open_clip_checkpoint(
    checkpoint,
    config_name,
@@ -1053,22 +1087,173 @@ def convert_open_clip_checkpoint(
    return text_model


def stable_unclip_image_encoder(original_config, local_files_only=False):
    """
    Returns the image processor and clip image encoder for the img2img unclip pipeline.

    We currently know of two types of stable unclip models which separately use the clip and the openclip image
    encoders.
    """

    image_embedder_config = original_config.model.params.embedder_config

    sd_clip_image_embedder_class = image_embedder_config.target
    sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1]

    if sd_clip_image_embedder_class == "ClipImageEmbedder":
        clip_model_name = image_embedder_config.params.model

        if clip_model_name == "ViT-L/14":
            feature_extractor = CLIPImageProcessor()
            image_encoder = CLIPVisionModelWithProjection.from_pretrained(
                "openai/clip-vit-large-patch14", local_files_only=local_files_only
            )
        else:
            raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}")

    elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder":
        feature_extractor = CLIPImageProcessor()
        image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", local_files_only=local_files_only
        )
    else:
        raise NotImplementedError(
            f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}"
        )

    return feature_extractor, image_encoder


def convert_paint_by_example_checkpoint(checkpoint, local_files_only=False):
    config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
    model = PaintByExampleImageEncoder(config)

    keys = list(checkpoint.keys())

    text_model_dict = {}

    for key in keys:
        if key.startswith("cond_stage_model.transformer"):
            text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]

    # load clip vision
    model.model.load_state_dict(text_model_dict)

    # load mapper
    keys_mapper = {
        k[len("cond_stage_model.mapper.res") :]: v
        for k, v in checkpoint.items()
        if k.startswith("cond_stage_model.mapper")
    }

    MAPPING = {
        "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
        "attn.c_proj": ["attn1.to_out.0"],
        "ln_1": ["norm1"],
        "ln_2": ["norm3"],
        "mlp.c_fc": ["ff.net.0.proj"],
        "mlp.c_proj": ["ff.net.2"],
    }

    mapped_weights = {}
    for key, value in keys_mapper.items():
        prefix = key[: len("blocks.i")]
        suffix = key.split(prefix)[-1].split(".")[-1]
        name = key.split(prefix)[-1].split(suffix)[0][1:-1]
        mapped_names = MAPPING[name]

        num_splits = len(mapped_names)
        for i, mapped_name in enumerate(mapped_names):
            new_name = ".".join([prefix, mapped_name, suffix])
            shape = value.shape[0] // num_splits
            mapped_weights[new_name] = value[i * shape : (i + 1) * shape]

    model.mapper.load_state_dict(mapped_weights)

    # load final layer norm
    model.final_layer_norm.load_state_dict(
        {
            "bias": checkpoint["cond_stage_model.final_ln.bias"],
            "weight": checkpoint["cond_stage_model.final_ln.weight"],
        }
    )

    # load final proj
    model.proj_out.load_state_dict(
        {
            "bias": checkpoint["proj_out.bias"],
            "weight": checkpoint["proj_out.weight"],
        }
    )

    # load uncond vector
    model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"])
    return model
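(The mapper loop above splits fused attention weights: an `attn.c_qkv` tensor stacks query, key and value along dim 0, so each mapped name receives one equal slice. A self-contained illustration:)

    import torch

    dim = 4
    fused = torch.randn(3 * dim, dim)  # stand-in for an attn.c_qkv weight
    shape = fused.shape[0] // 3        # one slice per mapped name (to_q, to_k, to_v)
    q, k, v = (fused[i * shape : (i + 1) * shape] for i in range(3))
    assert q.shape == k.shape == v.shape == (dim, dim)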


def stable_unclip_image_noising_components(
    original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None
):
    """
    Returns the noising components for the img2img and txt2img unclip pipelines.

    Converts the stability noise augmentor into
    1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats
    2. a `DDPMScheduler` for holding the noise schedule

    If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided.
    """
    noise_aug_config = original_config.model.params.noise_aug_config
    noise_aug_class = noise_aug_config.target
    noise_aug_class = noise_aug_class.split(".")[-1]

    if noise_aug_class == "CLIPEmbeddingNoiseAugmentation":
        noise_aug_config = noise_aug_config.params
        embedding_dim = noise_aug_config.timestep_dim
        max_noise_level = noise_aug_config.noise_schedule_config.timesteps
        beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule

        image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim)
        image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule)

        if "clip_stats_path" in noise_aug_config:
            if clip_stats_path is None:
                raise ValueError("This stable unclip config requires a `clip_stats_path`")

            clip_mean, clip_std = torch.load(clip_stats_path, map_location=device)
            clip_mean = clip_mean[None, :]
            clip_std = clip_std[None, :]

            clip_stats_state_dict = {
                "mean": clip_mean,
                "std": clip_std,
            }

            image_normalizer.load_state_dict(clip_stats_state_dict)
    else:
        raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}")

    return image_normalizer, image_noising_scheduler


def create_unet_model(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, image_size, **kwargs):
    if "num_in_channels" in kwargs:
-        num_in_channels = kwargs.pop("num_in_channels")
+        num_in_channels = kwargs.get("num_in_channels")

    elif pipeline_class_name in [
        "StableDiffusionInpaintPipeline",
        "StableDiffusionXLInpaintPipeline",
-        "StableDiffusionXLControlNetInpaintPipeline"]:
+        "StableDiffusionXLControlNetInpaintPipeline",
+    ]:
        num_in_channels = 9

    elif pipeline_class_name == "StableDiffusionUpscalePipeline":
        num_in_channels = 7

    else:
        num_in_channels = 4

-    if "upcast_attention" in kwargs:
-        upcast_attention = kwargs.pop("upcast_attention")

+    upcast_attention = kwargs.get("upcast_attention", False)
+    extract_ema = kwargs.get("extract_ema", False)

    unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
@@ -1092,9 +1277,8 @@ def create_unet_model(pipeline_class_name, original_config, checkpoint, checkpoi
    return unet


-def create_vae_model(original_config, checkpoint, checkpoint_path_or_dict, model_type, image_size, **kwargs):
+def create_vae_model(original_config, checkpoint, checkpoint_path_or_dict, **kwargs):
    vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
    path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else ""
    diffusers_format_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
    ctx = init_empty_weights if is_accelerate_available() else nullcontext
    with ctx():
@@ -1109,6 +1293,269 @@ def create_vae_model(original_config, checkpoint, checkpoint_path_or_dict, model
    return vae


def create_text_encoder_components(
    pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs
):
    model_type = infer_model_type(pipeline_class_name, original_config)
    local_files_only = kwargs.get("local_files_only", False)

    if model_type == "FrozenOpenCLIPEmbedder":
        config_name = "stabilityai/stable-diffusion-2"
        config_kwargs = {"subfolder": "text_encoder"}

        try:
            text_encoder = convert_open_clip_checkpoint(
                checkpoint, config_name, local_files_only=local_files_only, **config_kwargs
            )
            tokenizer = CLIPTokenizer.from_pretrained(
                config_name, subfolder="tokenizer", local_files_only=local_files_only
            )
        except Exception:
            raise ValueError(
                f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder in the following path: '{config_name}'."
            )
        else:
            return {"text_encoder": text_encoder, "tokenizer": tokenizer}

    elif model_type == "FrozenCLIPEmbedder":
        try:
            config_name = "openai/clip-vit-large-patch14"
            text_encoder = convert_ldm_clip_checkpoint(
                checkpoint, local_files_only=local_files_only, text_encoder=None
            )
            tokenizer = CLIPTokenizer.from_pretrained(
                config_name, subfolder="tokenizer", local_files_only=local_files_only
            )

        except Exception:
            raise ValueError(
                f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: '{config_name}'."
            )
        else:
            return {"text_encoder": text_encoder, "tokenizer": tokenizer}

    elif model_type == "SDXL-Refiner":
        config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
        config_kwargs = {"projection_dim": 1280}
        prefix = "conditioner.embedders.0.model."

        try:
            tokenizer_2 = CLIPTokenizer.from_pretrained(config_name, pad_token="!", local_files_only=local_files_only)
            text_encoder_2 = convert_open_clip_checkpoint(
                checkpoint,
                config_name,
                prefix=prefix,
                has_projection=True,
                local_files_only=local_files_only,
                **config_kwargs,
            )
        except Exception:
            raise ValueError(
                f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder_2 and tokenizer_2 in the following path: {config_name} with `pad_token` set to '!'."
            )

        else:
            return {
                "tokenizer_2": tokenizer_2,
                "text_encoder_2": text_encoder_2,
            }

    elif model_type == "SDXL":
        try:
            config_name = "openai/clip-vit-large-patch14"
            tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only)
            text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)

        except Exception:
            raise ValueError(
                f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder and tokenizer in the following path: 'openai/clip-vit-large-patch14'."
            )

        try:
            config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
            prefix = "conditioner.embedders.1.model."
            config_kwargs = {"projection_dim": 1280}  # assumed: missing from the extracted diff but required by the call below
            tokenizer_2 = CLIPTokenizer.from_pretrained(config_name, pad_token="!", local_files_only=local_files_only)
            text_encoder_2 = convert_open_clip_checkpoint(
                checkpoint,
                config_name,
                prefix=prefix,
                has_projection=True,
                local_files_only=local_files_only,
                **config_kwargs,
            )
        except Exception:
            raise ValueError(
                f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder_2 and tokenizer_2 in the following path: {config_name} with `pad_token` set to '!'."
            )

        return {
            "tokenizer": tokenizer,
            "text_encoder": text_encoder,
            "tokenizer_2": tokenizer_2,
            "text_encoder_2": text_encoder_2,
        }

    return


def create_scheduler_component(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs):
    scheduler_config = get_default_scheduler_config()
    model_type = infer_model_type(pipeline_class_name, original_config)

    scheduler_type = kwargs.get("scheduler_type", "ddim")
    prediction_type = kwargs.get("prediction_type", None)
    global_step = checkpoint["global_step"] if "global_step" in checkpoint else None

    num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000
    scheduler_config["num_train_timesteps"] = num_train_timesteps

    if (
        "parameterization" in original_config["model"]["params"]
        and original_config["model"]["params"]["parameterization"] == "v"
    ):
        if prediction_type is None:
            # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"`
            # as it relies on a brittle global step parameter here
            prediction_type = "epsilon" if global_step == 875000 else "v_prediction"

    else:
        prediction_type = prediction_type or "epsilon"

    scheduler_config["prediction_type"] = prediction_type

    if model_type in ["SDXL", "SDXL-Refiner"]:
        scheduler_type = "euler"

    else:
        beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02
        beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085
        scheduler_config["beta_start"] = beta_start
        scheduler_config["beta_end"] = beta_end
        scheduler_config["beta_schedule"] = "scaled_linear"
        scheduler_config["clip_sample"] = False
        scheduler_config["set_alpha_to_one"] = False

        scheduler_type = "ddim"

    if scheduler_type == "pndm":
        scheduler_config["skip_prk_steps"] = True
        scheduler = PNDMScheduler.from_config(scheduler_config)

    elif scheduler_type == "lms":
        scheduler = LMSDiscreteScheduler.from_config(scheduler_config)

    elif scheduler_type == "heun":
        scheduler = HeunDiscreteScheduler.from_config(scheduler_config)

    elif scheduler_type == "euler":
        scheduler = EulerDiscreteScheduler.from_config(scheduler_config)

    elif scheduler_type == "euler-ancestral":
        scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler_config)

    elif scheduler_type == "dpm":
        scheduler = DPMSolverMultistepScheduler.from_config(scheduler_config)

    elif scheduler_type == "ddim":
        scheduler = DDIMScheduler.from_config(scheduler_config)

    else:
        raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")

    return {"scheduler": scheduler}


def create_stable_unclip_components(
    pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs
):
    components = {}

    local_files_only = kwargs.get("local_files_only", False)
    clip_stats_path = kwargs.get("clip_stats_path", None)

    image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components(
        original_config,
        clip_stats_path=clip_stats_path,
    )

    if pipeline_class_name == "StableUnCLIPPipeline":
        stable_unclip_prior = kwargs.get("stable_unclip_prior", None)
        if stable_unclip_prior is not None and stable_unclip_prior != "karlo":
            raise NotImplementedError(f"Unknown prior for Stable UnCLIP model: {stable_unclip_prior}")

        try:
            config_name = "kakaobrain/karlo-v1-alpha"
            prior = PriorTransformer.from_pretrained(config_name, subfolder="prior", local_files_only=local_files_only)
        except Exception as e:
            raise ValueError(
                f"With local_files_only set to {local_files_only}, you must first locally save the prior in the following path: '{config_name}'."
            )

        try:
            config_name = "openai/clip-vit-large-patch14"
            prior_tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only)
            prior_text_encoder = CLIPTextModelWithProjection.from_pretrained(
                config_name, local_files_only=local_files_only
            )
            prior_scheduler = DDPMScheduler.from_pretrained(
                config_name, subfolder="prior_scheduler", local_files_only=local_files_only
            )

        except Exception:
            raise ValueError(
                f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: '{config_name}'."
            )
        else:
            return {
                "prior": prior,
                "prior_tokenizer": prior_tokenizer,
                "prior_text_encoder": prior_text_encoder,
                "prior_scheduler": prior_scheduler,
                "image_normalizer": image_normalizer,
                "image_noise_scheduler": image_noising_scheduler,
            }

    else:
        feature_extractor, image_encoder = stable_unclip_image_encoder(original_config)

        return {
            "feature_extractor": feature_extractor,
            "image_encoder": image_encoder,
            "image_normalizer": image_normalizer,
            "image_noising_scheduler": image_noising_scheduler,
        }

    return


def create_paint_by_example_components(
    pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs
):
    local_files_only = kwargs.get("local_files_only", False)
    image_encoder = convert_paint_by_example_checkpoint(checkpoint)

    try:
        config_name = "openai/clip-vit-large-patch14"
        tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only)
    except Exception:
        raise ValueError(
            f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'."
        )

    try:
        config_name = "CompVis/stable-diffusion-safety-checker"
        feature_extractor = AutoFeatureExtractor.from_pretrained(config_name, local_files_only=local_files_only)
    except Exception:
        raise ValueError(
            f"With local_files_only set to {local_files_only}, you must first locally save the feature_extractor in the following path: 'CompVis/stable-diffusion-safety-checker'."
        )

    return {
        "image_encoder": image_encoder,
        "tokenizer": tokenizer,
        "feature_extractor": feature_extractor,
    }


def download_from_original_stable_diffusion_ckpt(
    checkpoint_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -1137,7 +1584,7 @@ def download_from_original_stable_diffusion_ckpt(
    tokenizer=None,
    tokenizer_2=None,
    config_files=None,
-    **kwargs
+    **kwargs,
) -> DiffusionPipeline:
    """
    Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
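(For orientation, a minimal call; the local file path is illustrative only:)

    pipe = download_from_original_stable_diffusion_ckpt(
        "./v1-5-pruned-emaonly.safetensors",
        from_safetensors=True,
    )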
@@ -1238,7 +1685,7 @@ def download_from_original_stable_diffusion_ckpt(
        checkpoint = checkpoint["state_dict"]

    original_config = fetch_original_config(checkpoint, config_files)
-    model_type = set_model_type(original_config, model_type)
+    model_type = infer_model_type(original_config, model_type)

    unet = create_unet_model(original_config, checkpoint, checkpoint_path_or_dict, model_type, image_size, **kwargs)
    vae = create_vae_model(original_config, checkpoint, checkpoint_path_or_dict, model_type, image_size, **kwargs)
@@ -1696,4 +2143,3 @@ def download_from_original_stable_diffusion_ckpt(
    pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)

    return pipe