1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
Dhruv Nair
2024-01-19 07:47:50 +00:00
parent 2483d516ab
commit dab7f014a8

View File

@@ -169,7 +169,6 @@ DIFFUSERS_TO_LDM_MAPPING = {
# NOTE(review): the LDM_* names below look like constants describing the original
# (LDM-format) checkpoint layout; the three *_KEY values are presumably state-dict
# key prefixes for the VAE, UNet and ControlNet sub-models — confirm at the use sites.
LDM_VAE_KEY = "first_stage_model."
LDM_UNET_KEY = "model.diffusion_model."
LDM_CONTROLNET_KEY = "control_model."
# Config identifier passed to CLIPTextConfig.from_pretrained when the CLIP text
# encoder is created from an LDM checkpoint.
LDM_CLIP_CONFIG_NAME = "openai/clip-vit-large-patch14"
# NOTE(review): presumably key prefixes stripped from CLIP text-encoder weights
# during conversion — verify against the code that consumes this list.
LDM_CLIP_PREFIX_TO_REMOVE = ["cond_stage_model.transformer.", "conditioner.embedders.0.transformer."]
# NOTE(review): presumably the text-projection width used for OpenCLIP encoders — confirm.
LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
@@ -237,23 +236,6 @@ def fetch_original_config(pipeline_class_name, checkpoint, original_config_file=
return original_config
def load_checkpoint(checkpoint_path_or_dict, device=None, from_safetensors=True):
    """Return a checkpoint from a file path or from an already-loaded dict.

    Args:
        checkpoint_path_or_dict: Either a ``str`` path to a checkpoint file or a
            ``dict``; a dict is returned unchanged (same object).
        device: ``map_location`` handed to ``torch.load``. When ``None`` it
            resolves to ``"cuda"`` if ``torch.cuda.is_available()`` and ``"cpu"``
            otherwise. It is only consulted when ``from_safetensors`` is false;
            the ``safe_load`` path always uses ``device="cpu"``.
        from_safetensors: When true (default) a path is read with ``safe_load``,
            otherwise with ``torch.load``.

    Returns:
        The loaded checkpoint, or the input dict itself.

    Raises:
        TypeError: If ``checkpoint_path_or_dict`` is neither ``str`` nor ``dict``.
    """
    if isinstance(checkpoint_path_or_dict, dict):
        return checkpoint_path_or_dict

    # Previously an unsupported type fell through both branches and surfaced as an
    # UnboundLocalError on the return statement; fail with a clear message instead.
    if not isinstance(checkpoint_path_or_dict, str):
        raise TypeError(
            "`checkpoint_path_or_dict` must be a str path or a dict, "
            f"got {type(checkpoint_path_or_dict).__name__}."
        )

    if from_safetensors:
        return safe_load(checkpoint_path_or_dict, device="cpu")

    # The device only matters for torch.load, so resolve the default lazily here.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # NOTE(review): torch.load unpickles the file; only use this path for trusted
    # checkpoints (or consider weights_only=True where the torch version allows).
    return torch.load(checkpoint_path_or_dict, map_location=device)
def infer_model_type(original_config, model_type=None):
if model_type is not None:
return model_type
@@ -918,9 +900,9 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
return new_checkpoint
def create_text_encoder_from_ldm_clip_checkpoint(checkpoint, local_files_only=False):
def create_text_encoder_from_ldm_clip_checkpoint(config_name, checkpoint, local_files_only=False):
try:
config = CLIPTextConfig.from_pretrained(LDM_CLIP_CONFIG_NAME, local_files_only=local_files_only)
config = CLIPTextConfig.from_pretrained(config_name, local_files_only=local_files_only)
except Exception:
raise ValueError(
f"With local_files_only set to {local_files_only}, you must first locally save the configuration in the following path: 'openai/clip-vit-large-patch14'."
@@ -1178,7 +1160,7 @@ def create_text_encoders_and_tokenizers_from_ldm(
try:
config_name = "openai/clip-vit-large-patch14"
tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only)
text_encoder = create_text_encoder_from_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
text_encoder = create_text_encoder_from_ldm_clip_checkpoint(config_name, checkpoint, local_files_only=local_files_only)
except Exception:
raise ValueError(