mirror of https://github.com/huggingface/diffusers.git

commit: update
@@ -19,6 +19,7 @@ from huggingface_hub.utils import validate_hf_hub_args
 from transformers import AutoFeatureExtractor
 
 from ..models.modeling_utils import load_state_dict
+from ..pipelines.pipeline_utils import _get_pipeline_class
 from ..pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from ..utils import (
     is_accelerate_available,
@@ -27,11 +28,11 @@ from ..utils import (
 )
 from ..utils.hub_utils import _get_model_file
 from .single_file_utils import (
-    create_controlnet_model,
-    create_scheduler,
-    create_text_encoders_and_tokenizers,
-    create_unet_model,
-    create_vae_model,
+    create_diffusers_controlnet_model_from_ldm,
+    create_diffusers_unet_model_from_ldm,
+    create_diffusers_vae_model_from_ldm,
+    create_scheduler_from_ldm,
+    create_text_encoders_and_tokenizers_from_ldm,
     fetch_original_config,
     infer_model_type,
 )
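For orientation while reading the hunks below, the commit renames the single-file conversion helpers so the LDM-to-diffusers direction is explicit in each name:

```python
# Renames introduced by this commit (old -> new), collected from the hunks below:
# create_controlnet_model              -> create_diffusers_controlnet_model_from_ldm
# create_unet_model                    -> create_diffusers_unet_model_from_ldm
# create_vae_model                     -> create_diffusers_vae_model_from_ldm
# create_scheduler                     -> create_scheduler_from_ldm
# create_text_encoders_and_tokenizers  -> create_text_encoders_and_tokenizers_from_ldm
# convert_ldm_clip_checkpoint          -> create_text_encoder_from_ldm_clip_checkpoint
# convert_open_clip_checkpoint         -> create_text_encoder_from_open_clip_checkpoint
# determine_image_size                 -> set_image_size
# build_component                      -> build_sub_model_components
# build_additional_components          -> set_additional_components
```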
@@ -96,46 +97,57 @@ def _extract_repo_id_and_weights_name(pretrained_model_name_or_path):
     return repo_id, weights_name
 
 
-def build_component(
+def build_sub_model_components(
     pipeline_components,
     pipeline_class_name,
     component_name,
     original_config,
     checkpoint,
-    checkpoint_path_or_dict,
+    local_files_only=False,
+    load_safety_checker=False,
     **kwargs,
 ):
-    if component_name in kwargs:
-        component = kwargs.pop(component_name, None)
-        return {component_name: component}
-
     if component_name in pipeline_components:
         return {}
 
-    load_safety_checker = kwargs.get("load_safety_checker", False)
-    local_files_only = kwargs.get("local_files_only", False)
+    model_type = kwargs.get("model_type", None)
+    image_size = kwargs.pop("image_size", None)
 
     if component_name == "unet":
-        unet_components = create_unet_model(
-            pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs
+        num_in_channels = kwargs.pop("num_in_channels", None)
+        unet_components = create_diffusers_unet_model_from_ldm(
+            pipeline_class_name, original_config, checkpoint, num_in_channels=num_in_channels, image_size=image_size
         )
         return unet_components
 
     if component_name == "vae":
-        vae_components = create_vae_model(
-            pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs
+        vae_components = create_diffusers_vae_model_from_ldm(
+            pipeline_class_name, original_config, checkpoint, image_size
        )
         return vae_components
 
     if component_name == "scheduler":
-        scheduler_components = create_scheduler(
-            pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs
+        scheduler_type = kwargs.get("scheduler_type", "ddim")
+        prediction_type = kwargs.get("prediction_type", None)
+
+        scheduler_components = create_scheduler_from_ldm(
+            pipeline_class_name,
+            original_config,
+            checkpoint,
+            scheduler_type=scheduler_type,
+            prediction_type=prediction_type,
+            model_type=model_type,
         )
 
         return scheduler_components
 
     if component_name in ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"]:
-        text_encoder_components = create_text_encoders_and_tokenizers(
-            pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs
+        text_encoder_components = create_text_encoders_and_tokenizers_from_ldm(
+            original_config,
+            checkpoint,
+            model_type=model_type,
+            local_files_only=local_files_only,
         )
         return text_encoder_components
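The renamed `build_sub_model_components` resolves one pipeline module at a time and hands back a `{name: object}` dict that the caller merges into `init_kwargs`. A minimal self-contained sketch of that contract, with hypothetical string stand-ins instead of the real conversion helpers:

```python
# Illustrative sketch of the dispatch contract; the builders here are
# hypothetical stand-ins, not the diffusers conversion helpers.
def build_sub_model_components_sketch(init_kwargs, component_name, **kwargs):
    if component_name in init_kwargs:
        return {}  # already resolved, e.g. passed in by the caller

    if component_name == "unet":
        num_in_channels = kwargs.pop("num_in_channels", 4)
        return {"unet": f"unet(in_channels={num_in_channels})"}
    if component_name == "vae":
        return {"vae": "vae"}
    if component_name in ("text_encoder", "tokenizer"):
        # one builder call may resolve several related components at once
        return {"text_encoder": "clip", "tokenizer": "clip-tokenizer"}
    return {}  # unknown names are skipped by the caller


init_kwargs = {}
for name in ("unet", "vae", "text_encoder", "tokenizer", "feature_extractor"):
    components = build_sub_model_components_sketch(init_kwargs, name)
    if not components:
        continue
    init_kwargs.update(components)

print(init_kwargs)
```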
@@ -156,7 +168,7 @@ def build_component(
     return
 
 
-def build_additional_components(
+def set_additional_components(
     pipeline_class_name,
     original_config,
     **kwargs,
@@ -282,36 +294,57 @@ class FromSingleFileMixin:
         original_config = fetch_original_config(class_name, checkpoint, original_config_file, config_files)
 
         if class_name == "AutoencoderKL":
-            component = create_vae_model(class_name, original_config, checkpoint, pretrained_model_link_or_path)
+            image_size = kwargs.pop("image_size", None)
+            component = create_diffusers_vae_model_from_ldm(
+                class_name, original_config, checkpoint, image_size=image_size
+            )
             return component["vae"]
 
         if class_name == "ControlNetModel":
-            component = create_controlnet_model(class_name, original_config, checkpoint, **kwargs)
+            upcast_attention = kwargs.pop("upcast_attention", False)
+            image_size = kwargs.pop("image_size", None)
+
+            component = create_diffusers_controlnet_model_from_ldm(
+                class_name, original_config, checkpoint, upcast_attention=upcast_attention, image_size=image_size
+            )
             return component["controlnet"]
 
-        component_names = extract_pipeline_component_names(cls)
-        pipeline_components = {}
-        for component in component_names:
-            components = build_component(
-                pipeline_components,
-                class_name,
-                component,
-                original_config,
-                checkpoint,
-                pretrained_model_link_or_path,
-                **kwargs,
-            )
-            if not components:
-                continue
-            pipeline_components.update(components)
+        pipeline_class = _get_pipeline_class(
+            cls,
+            config=None,
+            cache_dir=cache_dir,
+        )
 
-        additional_components = set(component_names - pipeline_components.keys())
+        expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
+        passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
+        passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
+
+        init_kwargs = {}
+        for name in expected_modules:
+            if name in passed_class_obj:
+                init_kwargs[name] = passed_class_obj[name]
+            else:
+                components = build_sub_model_components(
+                    init_kwargs,
+                    class_name,
+                    name,
+                    original_config,
+                    checkpoint,
+                    pretrained_model_link_or_path,
+                    **kwargs,
+                )
+                if not components:
+                    continue
+                init_kwargs.update(components)
+
+        additional_components = set(optional_kwargs - init_kwargs.keys())
         if additional_components:
-            components = build_additional_components(class_name, original_config, **kwargs)
+            components = set_additional_components(class_name, original_config, **kwargs)
             if components:
-                pipeline_components.update(components)
+                init_kwargs.update(components)
 
-        pipe = cls(**pipeline_components)
+        init_kwargs.update(passed_pipe_kwargs)
+        pipe = pipeline_class(**init_kwargs)
 
         if torch_dtype is not None:
             pipe.to(dtype=torch_dtype)
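Net effect of the hunk above: `from_single_file` now resolves the concrete pipeline class up front with `_get_pipeline_class`, fills `init_kwargs` keyed by the pipeline's signature, and instantiates `pipeline_class(**init_kwargs)` instead of `cls(**pipeline_components)`. The public entry point is unchanged; a typical call, with a placeholder checkpoint filename:

```python
import torch
from diffusers import StableDiffusionPipeline

# Placeholder path; any single-file LDM checkpoint (.ckpt/.safetensors) works.
# torch_dtype is applied at the end of from_single_file via pipe.to(dtype=...).
pipe = StableDiffusionPipeline.from_single_file(
    "v1-5-pruned-emaonly.safetensors",
    torch_dtype=torch.float16,
)
```

The remaining hunks belong to the commit's second file, presumably diffusers' single_file_utils.py given the `from .single_file_utils import` block above.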
@@ -14,7 +14,6 @@
 # limitations under the License.
 """ Conversion script for the Stable Diffusion checkpoints."""
 
-
 import re
 from contextlib import nullcontext
 from io import BytesIO
@@ -188,30 +187,6 @@ SD_2_TEXT_ENCODER_KEYS_TO_IGNORE = [
     "cond_stage_model.model.text_projection",
 ]
 
-textenc_conversion_lst = [
-    ("positional_embedding", "text_model.embeddings.position_embedding.weight"),
-    ("token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
-    ("ln_final.weight", "text_model.final_layer_norm.weight"),
-    ("ln_final.bias", "text_model.final_layer_norm.bias"),
-    ("text_projection", "text_projection.weight"),
-]
-textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst}
-
-textenc_transformer_conversion_lst = [
-    # (stable-diffusion, HF Diffusers)
-    ("resblocks.", "text_model.encoder.layers."),
-    ("ln_1", "layer_norm1"),
-    ("ln_2", "layer_norm2"),
-    (".c_fc.", ".fc1."),
-    (".c_proj.", ".fc2."),
-    (".attn", ".self_attn"),
-    ("ln_final.", "transformer.text_model.final_layer_norm."),
-    ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
-    ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
-]
-protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
-textenc_pattern = re.compile("|".join(protected.keys()))
-
 
 def fetch_original_config_file_from_url(class_name, checkpoint):
     if CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024:
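The deleted tables drove a one-pass regex rename of OpenCLIP text-encoder keys. For reference, the technique they implemented, runnable in isolation with two of the removed mappings:

```python
import re

# Two of the removed (stable-diffusion -> HF transformers) key mappings.
textenc_transformer_conversion_lst = [
    ("resblocks.", "text_model.encoder.layers."),
    ("ln_1", "layer_norm1"),
]
protected = {re.escape(src): dst for src, dst in textenc_transformer_conversion_lst}
textenc_pattern = re.compile("|".join(protected.keys()))

# A single substitution pass renames every matched fragment in a checkpoint key.
key = "resblocks.0.ln_1.weight"
new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], key)
print(new_key)  # text_model.encoder.layers.0.layer_norm1.weight
```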
@@ -284,7 +259,7 @@ def load_checkpoint(checkpoint_path_or_dict, device=None, from_safetensors=True):
     return checkpoint
 
 
-def infer_model_type(pipeline_class_name, original_config, model_type=None, **kwargs):
+def infer_model_type(original_config, model_type=None):
     if model_type is not None:
         return model_type
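`infer_model_type` now takes the config first and an optional explicit override. The inference itself is outside this hunk; a hedged sketch of the idea, assuming the usual LDM YAML layout (the real implementation inspects more fields):

```python
# Sketch only: assumes the common LDM config layout
# model.params.cond_stage_config.target = "...FrozenOpenCLIPEmbedder".
def infer_model_type_sketch(original_config, model_type=None):
    if model_type is not None:
        return model_type  # explicit override wins, as in the diff
    target = original_config["model"]["params"]["cond_stage_config"]["target"]
    return target.split(".")[-1]


config = {
    "model": {
        "params": {
            "cond_stage_config": {"target": "ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder"}
        }
    }
}
print(infer_model_type_sketch(config))  # FrozenOpenCLIPEmbedder
```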
@@ -318,10 +293,12 @@ def get_default_scheduler_config():
     return SCHEDULER_DEFAULT_CONFIG
 
 
-def determine_image_size(pipeline_class_name, original_config, checkpoint, **kwargs):
-    image_size = kwargs.get("image_size", 512)
+def set_image_size(pipeline_class_name, original_config, checkpoint, image_size=None, model_type=None):
     if image_size:
         return image_size
 
     global_step = checkpoint["global_step"] if "global_step" in checkpoint else None
-    model_type = infer_model_type(pipeline_class_name, original_config, **kwargs)
+    model_type = infer_model_type(original_config, model_type)
 
     if pipeline_class_name == "StableDiffusionUpscalePipeline":
         image_size = original_config["model"]["params"].unet_config.params.image_size
@@ -340,7 +317,9 @@ def determine_image_size(pipeline_class_name, original_config, checkpoint, **kwargs):
         image_size = 512 if global_step == 875000 else 768
         return image_size
 
-    return image_size
+    else:
+        image_size = 512
+        return image_size
 
 
 # Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear
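Condensed, the new resolution order in `set_image_size` is: an explicit `image_size` wins; otherwise v2-era checkpoints resolve to 512 or 768 depending on `global_step`; everything else falls back to 512 via the new `else` branch. A runnable distillation, simplified from the surrounding file (the real v2 check is more involved than a bare `model_type` comparison):

```python
# Simplified: explicit size > inferred SD 2.x size > 512 default.
def set_image_size_sketch(image_size=None, model_type=None, global_step=None):
    if image_size:
        return image_size
    if model_type == "FrozenOpenCLIPEmbedder":
        # SD 2.x: the 875k-step base checkpoint is 512, later v2 checkpoints are 768.
        return 512 if global_step == 875000 else 768
    return 512


print(set_image_size_sketch())                                     # 512
print(set_image_size_sketch(model_type="FrozenOpenCLIPEmbedder"))  # 768
print(set_image_size_sketch(image_size=1024))                      # 1024
```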
@@ -526,41 +505,36 @@ update_unet_attention_ldm_to_diffusers(ldm_keys, new_checkpoint, checkpoint,
     new_checkpoint[diffusers_key] = checkpoint.pop(ldm_key)
 
 
-def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, skip_extract_state_dict=False):
+def convert_ldm_unet_checkpoint(checkpoint, config, extract_ema=False):
     """
     Takes a state dict and a config, and returns a converted checkpoint.
     """
 
-    if skip_extract_state_dict:
-        unet_state_dict = checkpoint
+    # extract state_dict for UNet
+    unet_state_dict = {}
+    keys = list(checkpoint.keys())
+    unet_key = LDM_UNET_KEY
+
+    # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
+    if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
+        logger.warning("Checkpoint has both EMA and non-EMA weights.")
+        logger.warning(
+            "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
+            " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
+        )
+        for key in keys:
+            if key.startswith("model.diffusion_model"):
+                flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
+                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
     else:
-        # extract state_dict for UNet
-        unet_state_dict = {}
-        keys = list(checkpoint.keys())
-
-        unet_key = LDM_UNET_KEY
-
-        # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
-        if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
-            logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.")
+        if sum(k.startswith("model_ema") for k in keys) > 100:
             logger.warning(
-                "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
-                " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
+                "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
+                " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
             )
-            for key in keys:
-                if key.startswith("model.diffusion_model"):
-                    flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
-                    unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
-        else:
-            if sum(k.startswith("model_ema") for k in keys) > 100:
-                logger.warning(
-                    "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
-                    " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
-                )
 
-            for key in keys:
-                if key.startswith(unet_key):
-                    unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
+        for key in keys:
+            if key.startswith(unet_key):
+                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
 
     new_checkpoint = {}
     ldm_unet_keys = DIFFUSERS_TO_LDM_MAPPING["unet"]["layers"]
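Note the flat naming convention the EMA branch relies on: LDM stores each EMA weight under a key that drops the leading `model.` segment and strips all remaining dots. The derivation in isolation:

```python
# Demonstrates the flat EMA key derivation used in the extract_ema branch above.
key = "model.diffusion_model.input_blocks.0.0.weight"
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
print(flat_ema_key)  # model_ema.diffusion_modelinput_blocks00weight
```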
@@ -792,7 +766,10 @@ def convert_controlnet_checkpoint(
     return new_checkpoint
 
 
-def create_controlnet_model(pipeline_class_name, original_config, checkpoint, **kwargs):
+def create_diffusers_controlnet_model_from_ldm(
+    pipeline_class_name, original_config, checkpoint, upcast_attention=False, image_size=None
+):
     # import here to avoid circular imports
     from ..models import ControlNetModel
 
     # NOTE: this while loop isn't great but this controlnet checkpoint has one additional
@@ -800,8 +777,7 @@ def create_controlnet_model(pipeline_class_name, original_config, checkpoint, **kwargs):
     while "state_dict" in checkpoint:
         checkpoint = checkpoint["state_dict"]
 
-    image_size = determine_image_size(pipeline_class_name, original_config, checkpoint, **kwargs)
-    upcast_attention = kwargs.get("upcast_attention", False)
+    image_size = set_image_size(pipeline_class_name, original_config, checkpoint, image_size=image_size)
 
     diffusers_config = create_controlnet_diffusers_config(original_config, image_size=image_size)
     diffusers_config["upcast_attention"] = upcast_attention
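The `while` loop guards against checkpoints that nest their weights under one or more `"state_dict"` keys; a toy illustration:

```python
# Some ControlNet checkpoints wrap the weights in (possibly nested) "state_dict" keys.
checkpoint = {"state_dict": {"state_dict": {"input_hint_block.0.weight": 0}}}
while "state_dict" in checkpoint:
    checkpoint = checkpoint["state_dict"]
print(checkpoint)  # {'input_hint_block.0.weight': 0}
```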
@@ -953,7 +929,7 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
     return new_checkpoint
 
 
-def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False):
+def create_text_encoder_from_ldm_clip_checkpoint(checkpoint, local_files_only=False):
     try:
         config = CLIPTextConfig.from_pretrained(LDM_CLIP_CONFIG_NAME, local_files_only=local_files_only)
     except Exception:
@@ -988,7 +964,7 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False):
     return text_model
 
 
-def convert_open_clip_checkpoint(
+def create_text_encoder_from_open_clip_checkpoint(
     checkpoint,
     config_name,
     prefix="cond_stage_model.model.",
@@ -1069,36 +1045,35 @@
     return text_model
 
 
-def create_unet_model(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs):
-    if "num_in_channels" in kwargs:
-        num_in_channels = kwargs.get("num_in_channels")
+def create_diffusers_unet_model_from_ldm(
+    pipeline_class_name,
+    original_config,
+    checkpoint,
+    num_in_channels=None,
+    upcast_attention=False,
+    extract_ema=False,
+    image_size=None,
+):
+    if num_in_channels is None:
+        if pipeline_class_name in [
+            "StableDiffusionInpaintPipeline",
+            "StableDiffusionXLInpaintPipeline",
+            "StableDiffusionXLControlNetInpaintPipeline",
+        ]:
+            num_in_channels = 9
 
-    elif pipeline_class_name in [
-        "StableDiffusionInpaintPipeline",
-        "StableDiffusionXLInpaintPipeline",
-        "StableDiffusionXLControlNetInpaintPipeline",
-    ]:
-        num_in_channels = 9
+        elif pipeline_class_name == "StableDiffusionUpscalePipeline":
+            num_in_channels = 7
 
-    elif pipeline_class_name == "StableDiffusionUpscalePipeline":
-        num_in_channels = 7
-
-    else:
-        num_in_channels = 4
-
-    image_size = determine_image_size(pipeline_class_name, original_config, checkpoint, **kwargs)
-
-    upcast_attention = kwargs.get("upcast_attention", False)
-    extract_ema = kwargs.get("extract_ema", False)
+        else:
+            num_in_channels = 4
 
+    image_size = set_image_size(pipeline_class_name, original_config, checkpoint, image_size=image_size)
     unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
     unet_config["in_channels"] = num_in_channels
     unet_config["upcast_attention"] = upcast_attention
 
-    path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else ""
-    diffusers_format_unet_checkpoint = convert_ldm_unet_checkpoint(
-        checkpoint, unet_config, path=path, extract_ema=extract_ema
-    )
+    diffusers_format_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config, extract_ema=extract_ema)
     ctx = init_empty_weights if is_accelerate_available() else nullcontext
     with ctx():
         unet = UNet2DConditionModel(**unet_config)
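Two details worth noting in the new signature: `num_in_channels` is only inferred when the caller leaves it as `None`, and the UNet is constructed under `init_empty_weights` when accelerate is available, so no real tensors are allocated before the converted state dict is loaded. The inference rule in isolation:

```python
# Mirror of the default in-channel rule from the hunk above.
def default_num_in_channels(pipeline_class_name, num_in_channels=None):
    if num_in_channels is not None:
        return num_in_channels
    if pipeline_class_name in (
        "StableDiffusionInpaintPipeline",
        "StableDiffusionXLInpaintPipeline",
        "StableDiffusionXLControlNetInpaintPipeline",
    ):
        return 9  # 4 latent + 4 masked-image latent + 1 mask channel
    if pipeline_class_name == "StableDiffusionUpscalePipeline":
        return 7
    return 4


print(default_num_in_channels("StableDiffusionInpaintPipeline"))  # 9
print(default_num_in_channels("StableDiffusionPipeline"))         # 4
```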
@@ -1112,10 +1087,16 @@ def create_unet_model(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs):
     return {"unet": unet}
 
 
-def create_vae_model(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs):
+def create_diffusers_vae_model_from_ldm(
+    pipeline_class_name,
+    original_config,
+    checkpoint,
+    image_size=None,
+):
     # import here to avoid circular imports
     from ..models import AutoencoderKL
 
-    image_size = determine_image_size(pipeline_class_name, original_config, checkpoint, **kwargs)
+    image_size = set_image_size(pipeline_class_name, original_config, checkpoint, image_size=image_size)
 
     vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
     diffusers_format_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
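Together with the `AutoencoderKL` branch at the top of `from_single_file`, this helper backs the standalone single-file VAE entry point; typical usage, with a placeholder filename:

```python
from diffusers import AutoencoderKL

# Placeholder path; image_size is optional and otherwise resolved by set_image_size.
vae = AutoencoderKL.from_single_file("vae-ft-mse-840000-ema-pruned.safetensors")
```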
@@ -1133,18 +1114,20 @@ def create_vae_model(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs):
     return {"vae": vae}
 
 
-def create_text_encoders_and_tokenizers(
-    pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs
+def create_text_encoders_and_tokenizers_from_ldm(
+    original_config,
+    checkpoint,
+    model_type=None,
+    local_files_only=False,
 ):
-    model_type = infer_model_type(pipeline_class_name, original_config)
-    local_files_only = kwargs.get("local_files_only", False)
+    model_type = infer_model_type(original_config, model_type=model_type)
 
     if model_type == "FrozenOpenCLIPEmbedder":
         config_name = "stabilityai/stable-diffusion-2"
         config_kwargs = {"subfolder": "text_encoder"}
 
         try:
-            text_encoder = convert_open_clip_checkpoint(
+            text_encoder = create_text_encoder_from_open_clip_checkpoint(
                 checkpoint, config_name, local_files_only=local_files_only, **config_kwargs
             )
             tokenizer = CLIPTokenizer.from_pretrained(
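The renamed helper keys its behavior off `model_type`: SD 2.x (`FrozenOpenCLIPEmbedder`) yields an OpenCLIP encoder/tokenizer pair, SD 1.x (`FrozenCLIPEmbedder`) a CLIP pair, and SDXL-style checkpoints two pairs, with the second tokenizer created with `pad_token="!"` as in the hunks below. A shape-only sketch of the returned dicts, with string stand-ins for the real models:

```python
# Shape-only sketch: the real code returns transformers models and tokenizers.
def create_text_encoders_and_tokenizers_sketch(model_type):
    if model_type == "FrozenOpenCLIPEmbedder":  # SD 2.x
        return {"text_encoder": "open-clip", "tokenizer": "clip-tokenizer"}
    if model_type == "FrozenCLIPEmbedder":      # SD 1.x
        return {"text_encoder": "clip", "tokenizer": "clip-tokenizer"}
    # SDXL-style checkpoints carry a second encoder/tokenizer pair.
    return {
        "text_encoder": "clip",
        "tokenizer": "clip-tokenizer",
        "text_encoder_2": "open-clip",
        "tokenizer_2": "clip-tokenizer (pad_token='!')",
    }


print(create_text_encoders_and_tokenizers_sketch("FrozenCLIPEmbedder"))
```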
@@ -1160,7 +1143,7 @@ def create_text_encoders_and_tokenizers(
     elif model_type == "FrozenCLIPEmbedder":
         try:
             config_name = "openai/clip-vit-large-patch14"
-            text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
+            text_encoder = create_text_encoder_from_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
             tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only)
 
         except Exception:
@@ -1177,7 +1160,7 @@
 
         try:
             tokenizer_2 = CLIPTokenizer.from_pretrained(config_name, pad_token="!", local_files_only=local_files_only)
-            text_encoder_2 = convert_open_clip_checkpoint(
+            text_encoder_2 = create_text_encoder_from_open_clip_checkpoint(
                 checkpoint,
                 config_name,
                 prefix=prefix,
@@ -1185,8 +1168,7 @@ def create_text_encoders_and_tokenizers(
                 local_files_only=local_files_only,
                 **config_kwargs,
             )
-        except Exception as e:
-            raise e
+        except Exception:
             raise ValueError(
                 f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder_2 and tokenizer_2 in the following path: {config_name} with `pad_token` set to '!'."
             )
@@ -1203,10 +1185,9 @@
         try:
             config_name = "openai/clip-vit-large-patch14"
             tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only)
-            text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
+            text_encoder = create_text_encoder_from_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
 
-        except Exception as e:
-            raise e
+        except Exception:
             raise ValueError(
                 f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder and tokenizer in the following path: 'openai/clip-vit-large-patch14'."
             )
@@ -1216,7 +1197,7 @@
         config_kwargs = {"projection_dim": 1280}
         prefix = "conditioner.embedders.1.model."
         tokenizer_2 = CLIPTokenizer.from_pretrained(config_name, pad_token="!", local_files_only=local_files_only)
-        text_encoder_2 = convert_open_clip_checkpoint(
+        text_encoder_2 = create_text_encoder_from_open_clip_checkpoint(
             checkpoint,
             config_name,
             prefix=prefix,
@@ -1239,12 +1220,17 @@
     return
 
 
-def create_scheduler(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs):
+def create_scheduler_from_ldm(
+    pipeline_class_name,
+    original_config,
+    checkpoint,
+    prediction_type=None,
+    scheduler_type="ddim",
+    model_type=None,
+):
     scheduler_config = get_default_scheduler_config()
-    model_type = infer_model_type(pipeline_class_name, original_config)
+    model_type = infer_model_type(original_config, model_type=model_type)
 
-    scheduler_type = kwargs.get("scheduler_type", "ddim")
-    prediction_type = kwargs.get("prediction_type", None)
     global_step = checkpoint["global_step"] if "global_step" in checkpoint else None
 
     num_train_timesteps = getattr(original_config["model"]["params"], "timesteps", None) or 1000
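`create_scheduler_from_ldm` now receives `scheduler_type` and `prediction_type` as keyword parameters instead of digging them out of `**kwargs`. Downstream, scheduler construction follows the usual diffusers pattern of instantiating a scheduler class from a config dict; a hedged sketch (the config values and the type-to-class map here are illustrative, not necessarily `SCHEDULER_DEFAULT_CONFIG` or the full set the helper supports):

```python
from diffusers import DDIMScheduler, EulerDiscreteScheduler

# Illustrative defaults (the common SD training betas); the real values
# come from get_default_scheduler_config().
scheduler_config = {
    "beta_start": 0.00085,
    "beta_end": 0.012,
    "beta_schedule": "scaled_linear",
    "num_train_timesteps": 1000,
    "prediction_type": "epsilon",
}

scheduler_type = "ddim"
scheduler_cls = {"ddim": DDIMScheduler, "euler": EulerDiscreteScheduler}[scheduler_type]
scheduler = scheduler_cls.from_config(scheduler_config)
print(scheduler.config.prediction_type)  # epsilon
```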