mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
Dhruv Nair
2024-01-15 12:54:10 +00:00
parent cf2fe1eaa0
commit cf560a715e


@@ -17,11 +17,10 @@
import re
from contextlib import nullcontext
from io import BytesIO
from typing import Optional
import requests
import torch
from omegaconf import OmegaConf
import yaml
from safetensors.torch import load_file as safe_load
from transformers import (
BertTokenizerFast,
@@ -30,14 +29,11 @@ from transformers import (
CLIPTextModel,
CLIPTextModelWithProjection,
CLIPTokenizer,
CLIPVisionConfig,
CLIPVisionModelWithProjection,
)
from ..models import AutoencoderKL, PriorTransformer, UNet2DConditionModel
from ..models import AutoencoderKL, UNet2DConditionModel
from ..pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
from ..pipelines.paint_by_example import PaintByExampleImageEncoder
from ..pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
from ..schedulers import (
DDIMScheduler,
DDPMScheduler,
@@ -85,6 +81,53 @@ SCHEDULER_DEFAULT_CONFIG = {
"timestep_spacing": "leading",
}
DIFFUSERS_TO_LDM_MAPPING = {
"unet": {
"time_embedding.linear_1.weight": "time_embed.0.weight",
"time_embedding.linear_1.bias": "time_embed.0.bias",
"time_embedding.linear_2.weight": "time_embed.2.weight",
"time_embedding.linear_2.bias": "time_embed.2.bias",
"conv_in.weight": "input_blocks.0.0.weight",
"conv_in.bias": "input_blocks.0.0.bias",
"class_embed_type": {
"timestep": {
"class_embedding.linear_1.weight": "label_emb.0.0.weight",
"class_embedding.linear_1.bias": "label_emb.0.0.bias",
"class_embedding.linear_2.weight": "label_emb.0.2.weight",
"class_embedding.linear_2.bias": "label_emb.0.2.bias",
},
"text_time": {
"class_embedding.linear_1.weight": "label_emb.0.0.weight",
"class_embedding.linear_1.bias": "label_emb.0.0.bias",
"class_embedding.linear_2.weight": "label_emb.0.2.weight",
"class_embedding.linear_2.bias": "label_emb.0.2.bias",
},
},
},
"vae": {
"encoder.conv_in.weight": "encoder.conv_in.weight",
"encoder.conv_in.bias": "encoder.conv_in.bias",
"encoder.conv_out.weight": "encoder.conv_out.weight",
"encoder.conv_out.bias": "encoder.conv_out.bias",
"encoder.conv_norm_out.weight": "encoder.conv_norm_out.weight",
"encoder.conv_norm_out.bias": "encoder.conv_norm_out.bias",
"decoder.conv_in.weight": "decoder.conv_in.weight",
"decoder.conv_in.bias": "decoder.conv_in.bias",
"decoder.conv_out.weight": "decoder.conv_out.weight",
"decoder.conv_out.bias": "decoder.conv_out.bias",
"decoder.conv_norm_out.weight": "decoder.conv_norm_out.weight",
"decoder.conv_norm_out.bias": "decoder.conv_norm_out.bias",
"quant_conv.weight": "quant_conv.weight",
"quant_conv.bias": "quant_conv.bias",
"post_quant_conv.weight": "post_quant_conv.weight",
"post_quant_conv.bias": "post_quant_conv.bias",
},
}
UNET_TIME_EMBEDDING_LAYERS = []
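The DIFFUSERS_TO_LDM_MAPPING table added above is a flat diffusers-key to LDM-key lookup for the layers whose names translate one-to-one (time embedding, conv_in/conv_out, the VAE quant convs). A minimal sketch of how such a table can be applied when renaming a loaded state dict; the helper and the inversion direction are illustrative, not the converter's actual code path:

def rename_with_mapping(ldm_state_dict, mapping):
    # Invert the diffusers -> LDM table and copy tensors whose names match directly.
    ldm_to_diffusers = {v: k for k, v in mapping.items() if isinstance(v, str)}
    converted = {}
    for ldm_key, tensor in ldm_state_dict.items():
        # Keys without a direct entry (attention/resnet blocks) still need the
        # pattern-based renaming done further down in this file.
        converted[ldm_to_diffusers.get(ldm_key, ldm_key)] = tensor
    return converted

# e.g. rename_with_mapping(unet_state_dict, DIFFUSERS_TO_LDM_MAPPING["unet"])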
textenc_conversion_lst = [
("positional_embedding", "text_model.embeddings.position_embedding.weight"),
("token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
@@ -147,7 +190,7 @@ def fetch_original_config_file_from_file(config_files: list):
def fetch_original_config(pipeline_class_name, checkpoint, original_config_file=None, config_files=None):
if original_config_file:
original_config = OmegaConf.load(original_config_file)
original_config = yaml.safe_load(original_config_file)
return original_config
elif config_files:
@@ -156,7 +199,7 @@ def fetch_original_config(pipeline_class_name, checkpoint, original_config_file=
else:
original_config_file = fetch_original_config_file_from_url(pipeline_class_name, checkpoint)
original_config = OmegaConf.load(original_config_file)
original_config = yaml.safe_load(original_config_file)
return original_config
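With OmegaConf gone, the original YAML config is parsed into plain nested dicts and lists, which is why attribute access (original_config.model.params) becomes subscript access throughout the rest of this diff. One caveat worth noting: yaml.safe_load expects YAML text or an open stream, not a filesystem path, so if original_config_file still holds a path at this point (the old OmegaConf.load accepted one) it has to be read first. A minimal sketch, with the helper name being illustrative:

import yaml

def load_original_config(path_or_text):
    # Accept either raw YAML text or a path to a .yaml file.
    if "\n" not in path_or_text and path_or_text.endswith((".yaml", ".yml")):
        with open(path_or_text, "r") as f:
            path_or_text = f.read()
    return yaml.safe_load(path_or_text)  # plain dicts, no attribute access

# config = load_original_config("v1-inference.yaml")
# config["model"]["params"]["unet_config"]["params"]["context_dim"]  # 768 for SD 1.x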
@@ -187,18 +230,19 @@ def infer_model_type(pipeline_class_name, original_config, model_type=None, **kw
return model_type
has_cond_stage_config = (
"cond_stage_config" in original_config.model.params
and original_config.model.params.cond_stage_config is not None
"cond_stage_config" in original_config["model"]["params"]
and original_config["model"]["params"]["cond_stage_config"] is not None
)
has_network_config = (
"network_config" in original_config.model.params and original_config.model.params.network_config is not None
"network_config" in original_config["model"]["params"]
and original_config["model"]["params"]["network_config"] is not None
)
if has_cond_stage_config:
model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
model_type = original_config["model"]["params"]["cond_stage_config"]["target"].split(".")[-1]
elif has_network_config:
context_dim = original_config.model.params.network_config.params.context_dim
context_dim = original_config["model"]["params"]["network_config"]["params"]["context_dim"]
if context_dim == 2048:
model_type = "SDXL"
else:
@@ -221,7 +265,7 @@ def determine_image_size(pipeline_class_name, original_config, checkpoint, **kwa
model_type = infer_model_type(pipeline_class_name, original_config, **kwargs)
if pipeline_class_name == "StableDiffusionUpscalePipeline":
image_size = original_config.model.params.unet_config.params.image_size
image_size = original_config["model"]["params"]["unet_config"]["params"]["image_size"]
return image_size
elif model_type in ["SDXL", "SDXL-Refiner"]:
@@ -413,57 +457,55 @@ def conv_attn_to_linear(checkpoint):
checkpoint[key] = checkpoint[key][:, :, 0]
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_unet_diffusers_config
def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
def create_unet_diffusers_config(original_config, image_size: int):
"""
Creates a config for the diffusers UNet based on the config of the LDM model.
"""
if controlnet:
unet_params = original_config.model.params.control_stage_config.params
if (
"unet_config" in original_config["model"]["params"]
and original_config["model"]["params"]["unet_config"] is not None
):
unet_params = original_config["model"]["params"]["unet_config"]["params"]
else:
if "unet_config" in original_config.model.params and original_config.model.params.unet_config is not None:
unet_params = original_config.model.params.unet_config.params
else:
unet_params = original_config.model.params.network_config.params
unet_params = original_config["model"]["params"]["network_config"]["params"]
vae_params = original_config.model.params.first_stage_config.params.ddconfig
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]
down_block_types = []
resolution = 1
for i in range(len(block_out_channels)):
block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D"
down_block_types.append(block_type)
if i != len(block_out_channels) - 1:
resolution *= 2
up_block_types = []
for i in range(len(block_out_channels)):
block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D"
up_block_types.append(block_type)
resolution //= 2
if unet_params.transformer_depth is not None:
if unet_params["transformer_depth"] is not None:
transformer_layers_per_block = (
unet_params.transformer_depth
if isinstance(unet_params.transformer_depth, int)
else list(unet_params.transformer_depth)
unet_params["transformer_depth"]
if isinstance(unet_params["transformer_depth"], int)
else list(unet_params["transformer_depth"])
)
else:
transformer_layers_per_block = 1
vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)
head_dim = unet_params.num_heads if "num_heads" in unet_params else None
head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None
use_linear_projection = (
unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False
)
if use_linear_projection:
# stable diffusion 2-base-512 and 2-768
if head_dim is None:
head_dim_mult = unet_params.model_channels // unet_params.num_head_channels
head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)]
head_dim_mult = unet_params["model_channels"] // unet_params["num_head_channels"]
head_dim = [head_dim_mult * c for c in list(unet_params["channel_mult"])]
class_embed_type = None
addition_embed_type = None
@@ -471,13 +513,15 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
projection_class_embeddings_input_dim = None
context_dim = None
if unet_params.context_dim is not None:
if unet_params["context_dim"] is not None:
context_dim = (
unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0]
unet_params["context_dim"]
if isinstance(unet_params["context_dim"], int)
else unet_params["context_dim"][0]
)
if "num_classes" in unet_params:
if unet_params.num_classes == "sequential":
if unet_params["num_classes"] == "sequential":
if context_dim in [2048, 1280]:
# SDXL
addition_embed_type = "text_time"
@@ -485,14 +529,14 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
else:
class_embed_type = "projection"
assert "adm_in_channels" in unet_params
projection_class_embeddings_input_dim = unet_params.adm_in_channels
projection_class_embeddings_input_dim = unet_params["adm_in_channels"]
config = {
"sample_size": image_size // vae_scale_factor,
"in_channels": unet_params.in_channels,
"in_channels": unet_params["in_channels"],
"down_block_types": tuple(down_block_types),
"block_out_channels": tuple(block_out_channels),
"layers_per_block": unet_params.num_res_blocks,
"layers_per_block": unet_params["num_res_blocks"],
"cross_attention_dim": context_dim,
"attention_head_dim": head_dim,
"use_linear_projection": use_linear_projection,
@@ -504,49 +548,42 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
}
if "disable_self_attentions" in unet_params:
config["only_cross_attention"] = unet_params.disable_self_attentions
config["only_cross_attention"] = unet_params["disable_self_attentions"]
if "num_classes" in unet_params and isinstance(unet_params.num_classes, int):
config["num_class_embeds"] = unet_params.num_classes
if "num_classes" in unet_params and isinstance(unet_params["num_classes"], int):
config["num_class_embeds"] = unet_params["num_classes"]
if controlnet:
config["conditioning_channels"] = unet_params.hint_channels
else:
config["out_channels"] = unet_params.out_channels
config["up_block_types"] = tuple(up_block_types)
config["out_channels"] = unet_params["out_channels"]
config["up_block_types"] = tuple(up_block_types)
return config
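As a spot check on the dict-based rewrite of create_unet_diffusers_config, here is the derivation for a standard SD 1.x checkpoint, using values taken from the reference v1-inference.yaml rather than from this diff:

unet_params = {  # model.params.unet_config.params in v1-inference.yaml
    "image_size": 32, "in_channels": 4, "out_channels": 4,
    "model_channels": 320, "attention_resolutions": [4, 2, 1],
    "num_res_blocks": 2, "channel_mult": [1, 2, 4, 4],
    "num_heads": 8, "transformer_depth": 1, "context_dim": 768,
}
block_out_channels = [unet_params["model_channels"] * m for m in unet_params["channel_mult"]]
# -> [320, 640, 1280, 1280]
down_block_types, resolution = [], 1
for i in range(len(block_out_channels)):
    is_attn = resolution in unet_params["attention_resolutions"]
    down_block_types.append("CrossAttnDownBlock2D" if is_attn else "DownBlock2D")
    if i != len(block_out_channels) - 1:
        resolution *= 2
# -> ["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"]
# With ddconfig ch_mult = [1, 2, 4, 4], vae_scale_factor = 2 ** 3 = 8, so for a
# 512-pixel model sample_size = 512 // 8 = 64, cross_attention_dim = 768, attention_head_dim = 8.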
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_vae_diffusers_config
def create_vae_diffusers_config(original_config, image_size: int):
"""
Creates a config for the diffusers VAE based on the config of the LDM model.
"""
vae_params = original_config.model.params.first_stage_config.params.ddconfig
_ = original_config.model.params.first_stage_config.params.embed_dim
vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
config = {
"sample_size": image_size,
"in_channels": vae_params.in_channels,
"out_channels": vae_params.out_ch,
"in_channels": vae_params["in_channels"],
"out_channels": vae_params["out_ch"],
"down_block_types": tuple(down_block_types),
"up_block_types": tuple(up_block_types),
"block_out_channels": tuple(block_out_channels),
"latent_channels": vae_params.z_channels,
"layers_per_block": vae_params.num_res_blocks,
"latent_channels": vae_params["z_channels"],
"layers_per_block": vae_params["num_res_blocks"],
}
return config
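The same spot check for create_vae_diffusers_config, again with the reference v1 ddconfig values (assumed, not read from this diff):

vae_params = {  # model.params.first_stage_config.params.ddconfig in v1-inference.yaml
    "ch": 128, "ch_mult": [1, 2, 4, 4], "z_channels": 4,
    "in_channels": 3, "out_ch": 3, "num_res_blocks": 2,
}
block_out_channels = [vae_params["ch"] * m for m in vae_params["ch_mult"]]
# -> [128, 256, 512, 512]; with image_size 512 the resulting config has
# sample_size=512, latent_channels=4, layers_per_block=2 and four
# DownEncoderBlock2D / UpDecoderBlock2D entries.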
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.convert_ldm_unet_checkpoint
def convert_ldm_unet_checkpoint(
checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False
):
def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, skip_extract_state_dict=False):
"""
Takes a state dict and a config, and returns a converted checkpoint.
"""
@@ -558,10 +595,7 @@ def convert_ldm_unet_checkpoint(
unet_state_dict = {}
keys = list(checkpoint.keys())
if controlnet:
unet_key = "control_model."
else:
unet_key = "model.diffusion_model."
unet_key = "model.diffusion_model."
# at least 100 parameters have to start with `model_ema` in order for the checkpoint to be treated as EMA
if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
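The body of the extract_ema branch is elided by this hunk, but the lookup it performs in the pre-existing converter follows CompVis' LitEma naming, where the EMA shadow of a parameter is stored under model_ema. plus the parameter name with its dots stripped. A small self-contained illustration of that key transform (assumed convention, not shown in the diff):

def flat_ema_key(key):
    # "model.diffusion_model.input_blocks.0.0.weight"
    #   -> "model_ema.diffusion_modelinput_blocks00weight"
    return "model_ema." + "".join(key.split(".")[1:])

print(flat_ema_key("model.diffusion_model.input_blocks.0.0.weight"))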
@@ -617,12 +651,10 @@ def convert_ldm_unet_checkpoint(
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
if not controlnet:
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
# Retrieves the keys for the input blocks only
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
input_blocks = {
@@ -747,48 +779,6 @@ def convert_ldm_unet_checkpoint(
new_checkpoint[new_path] = unet_state_dict[old_path]
if controlnet:
# conditioning embedding
orig_index = 0
new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop(
f"input_hint_block.{orig_index}.weight"
)
new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop(
f"input_hint_block.{orig_index}.bias"
)
orig_index += 2
diffusers_index = 0
while diffusers_index < 6:
new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
f"input_hint_block.{orig_index}.weight"
)
new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop(
f"input_hint_block.{orig_index}.bias"
)
diffusers_index += 1
orig_index += 2
new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
f"input_hint_block.{orig_index}.weight"
)
new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
f"input_hint_block.{orig_index}.bias"
)
# down blocks
for i in range(num_input_blocks):
new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight")
new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias")
# mid block
new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")
return new_checkpoint
@@ -824,13 +814,13 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
# Retrieves the keys for the encoder down blocks only
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
num_down_blocks = len(config["down_block_types"])
down_blocks = {
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
}
# Retrieves the keys for the decoder up blocks only
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
num_up_blocks = len(config["up_block_types"])
up_blocks = {
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
}
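Counting VAE blocks from the freshly built diffusers config instead of scanning checkpoint keys gives the same answer for the stock SD autoencoder (four encoder and four decoder blocks) while no longer depending on which keys happen to be present in a pruned or partially converted state dict; the motivation is inferred here, not stated in the commit. A short illustration:

config = {"down_block_types": ("DownEncoderBlock2D",) * 4, "up_block_types": ("UpDecoderBlock2D",) * 4}
num_down_blocks = len(config["down_block_types"])  # 4, independent of the checkpoint's key layout
num_up_blocks = len(config["up_block_types"])      # 4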
@@ -1082,7 +1072,7 @@ def stable_unclip_image_encoder(original_config, local_files_only=False):
encoders.
"""
image_embedder_config = original_config.model.params.embedder_config
image_embedder_config = original_config["model"]["params"]["embedder_config"]
sd_clip_image_embedder_class = image_embedder_config["target"]
sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1]
@@ -1111,120 +1101,8 @@ def stable_unclip_image_encoder(original_config, local_files_only=False):
return feature_extractor, image_encoder
def convert_paint_by_example_checkpoint(checkpoint, local_files_only=False):
config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
model = PaintByExampleImageEncoder(config)
keys = list(checkpoint.keys())
text_model_dict = {}
for key in keys:
if key.startswith("cond_stage_model.transformer"):
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
# load clip vision
model.model.load_state_dict(text_model_dict)
# load mapper
keys_mapper = {
k[len("cond_stage_model.mapper.res") :]: v
for k, v in checkpoint.items()
if k.startswith("cond_stage_model.mapper")
}
MAPPING = {
"attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
"attn.c_proj": ["attn1.to_out.0"],
"ln_1": ["norm1"],
"ln_2": ["norm3"],
"mlp.c_fc": ["ff.net.0.proj"],
"mlp.c_proj": ["ff.net.2"],
}
mapped_weights = {}
for key, value in keys_mapper.items():
prefix = key[: len("blocks.i")]
suffix = key.split(prefix)[-1].split(".")[-1]
name = key.split(prefix)[-1].split(suffix)[0][1:-1]
mapped_names = MAPPING[name]
num_splits = len(mapped_names)
for i, mapped_name in enumerate(mapped_names):
new_name = ".".join([prefix, mapped_name, suffix])
shape = value.shape[0] // num_splits
mapped_weights[new_name] = value[i * shape : (i + 1) * shape]
model.mapper.load_state_dict(mapped_weights)
# load final layer norm
model.final_layer_norm.load_state_dict(
{
"bias": checkpoint["cond_stage_model.final_ln.bias"],
"weight": checkpoint["cond_stage_model.final_ln.weight"],
}
)
# load final proj
model.proj_out.load_state_dict(
{
"bias": checkpoint["proj_out.bias"],
"weight": checkpoint["proj_out.weight"],
}
)
# load uncond vector
model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"])
return model
def stable_unclip_image_noising_components(
original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None
):
"""
Returns the noising components for the img2img and txt2img unclip pipelines.
Converts the stability noise augmentor into
1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats
2. a `DDPMScheduler` for holding the noise schedule
If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided.
"""
noise_aug_config = original_config.model.params.noise_aug_config
noise_aug_class = noise_aug_config.target
noise_aug_class = noise_aug_class.split(".")[-1]
if noise_aug_class == "CLIPEmbeddingNoiseAugmentation":
noise_aug_config = noise_aug_config.params
embedding_dim = noise_aug_config.timestep_dim
max_noise_level = noise_aug_config.noise_schedule_config.timesteps
beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule
image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim)
image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule)
if "clip_stats_path" in noise_aug_config:
if clip_stats_path is None:
raise ValueError("This stable unclip config requires a `clip_stats_path`")
clip_mean, clip_std = torch.load(clip_stats_path, map_location=device)
clip_mean = clip_mean[None, :]
clip_std = clip_std[None, :]
clip_stats_state_dict = {
"mean": clip_mean,
"std": clip_std,
}
image_normalizer.load_state_dict(clip_stats_state_dict)
else:
raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}")
return image_normalizer, image_noising_scheduler
def create_ldm_bert_config(original_config):
bert_params = original_config.model.params.cond_stage_config.params
bert_params = original_config["model"]["params"]["cond_stage_config"]["params"]
config = LDMBertConfig(
d_model=bert_params["n_embed"],
encoder_layers=bert_params["n_layer"],
@@ -1416,7 +1294,7 @@ def create_scheduler(pipeline_class_name, original_config, checkpoint, checkpoin
prediction_type = kwargs.get("prediction_type", None)
global_step = checkpoint["global_step"] if "global_step" in checkpoint else None
num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000
num_train_timesteps = original_config["model"]["params"].get("timesteps", None) or 1000
scheduler_config["num_train_timesteps"] = num_train_timesteps
if (
@@ -1437,8 +1315,8 @@ def create_scheduler(pipeline_class_name, original_config, checkpoint, checkpoin
scheduler_type = "euler"
else:
beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02
beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085
beta_start = original_config["model"]["params"].get("linear_start", None) or 0.02
beta_end = original_config["model"]["params"].get("linear_end", None) or 0.085
scheduler_config["beta_start"] = beta_start
scheduler_config["beta_end"] = beta_end
scheduler_config["beta_schedule"] = "scaled_linear"
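With the getattr calls above switched to dict .get() lookups, the scheduler settings for a stock v1 checkpoint (linear_start=0.00085, linear_end=0.012, timesteps=1000 in the reference config; assumed values) come out to the familiar scaled-linear schedule. A minimal sketch of the resulting kwargs:

model_params = {"timesteps": 1000, "linear_start": 0.00085, "linear_end": 0.012}
scheduler_config = dict(SCHEDULER_DEFAULT_CONFIG)  # defined near the top of this file
scheduler_config["num_train_timesteps"] = model_params.get("timesteps", None) or 1000
scheduler_config["beta_start"] = model_params.get("linear_start", None) or 0.02
scheduler_config["beta_end"] = model_params.get("linear_end", None) or 0.085
scheduler_config["beta_schedule"] = "scaled_linear"
# -> num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, i.e. the
#    standard Stable Diffusion DDIM/PNDM schedule.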
@@ -1484,64 +1362,3 @@ def create_scheduler(pipeline_class_name, original_config, checkpoint, checkpoin
}
return {"scheduler": scheduler}
def create_stable_unclip_components(
pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs
):
local_files_only = kwargs.get("local_files_only", False)
clip_stats_path = kwargs.get("clip_stats_path", None)
image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components(
original_config,
clip_stats_path=clip_stats_path,
)
if pipeline_class_name == "StableUnCLIPPipeline":
stable_unclip_prior = kwargs.get("stable_unclip_prior", None)
if stable_unclip_prior is None and stable_unclip_prior != "karlo":
raise NotImplementedError(f"Unknown prior for Stable UnCLIP model: {stable_unclip_prior}")
try:
config_name = "kakaobrain/karlo-v1-alpha"
prior = PriorTransformer.from_pretrained(config_name, subfolder="prior", local_files_only=local_files_only)
except Exception:
raise ValueError(
f"With local_files_only set to {local_files_only}, you must first locally save the prior in the following path: '{config_name}'."
)
try:
config_name = "openai/clip-vit-large-patch14"
prior_tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only)
prior_text_encoder = CLIPTextModelWithProjection.from_pretrained(
config_name, local_files_only=local_files_only
)
prior_scheduler = DDPMScheduler.from_pretrained(
config_name, subfolder="prior_scheduler", local_files_only=local_files_only
)
except Exception:
raise ValueError(
f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: '{config_name}'."
)
else:
return {
"prior": prior,
"prior_tokenizer": prior_tokenizer,
"prior_text_encoder": prior_text_encoder,
"prior_scheduler": prior_scheduler,
"image_normalizer": image_normalizer,
"image_noise_scheduler": image_noising_scheduler,
}
else:
feature_extractor, image_encoder = stable_unclip_image_encoder(original_config)
return {
"feature_extractor": feature_extractor,
"image_encoder": image_encoder,
"image_normalizer": image_normalizer,
"image_noising_scheduler": image_noising_scheduler,
}
return