From cf560a715e7aae4b8268244a9d83c62d3eb0bdfd Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 15 Jan 2024 12:54:10 +0000 Subject: [PATCH] update --- src/diffusers/loaders/single_file_utils.py | 405 ++++++--------------- 1 file changed, 111 insertions(+), 294 deletions(-) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index b8c2dcba34..6f71133b58 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -17,11 +17,10 @@ import re from contextlib import nullcontext from io import BytesIO -from typing import Optional import requests import torch -from omegaconf import OmegaConf +import yaml from safetensors.torch import load_file as safe_load from transformers import ( BertTokenizerFast, @@ -30,14 +29,11 @@ from transformers import ( CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, - CLIPVisionConfig, CLIPVisionModelWithProjection, ) -from ..models import AutoencoderKL, PriorTransformer, UNet2DConditionModel +from ..models import AutoencoderKL, UNet2DConditionModel from ..pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel -from ..pipelines.paint_by_example import PaintByExampleImageEncoder -from ..pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer from ..schedulers import ( DDIMScheduler, DDPMScheduler, @@ -85,6 +81,53 @@ SCHEDULER_DEFAULT_CONFIG = { "timestep_spacing": "leading", } +DIFFUSERS_TO_LDM_MAPPING = { + "unet": { + "time_embedding.linear_1.weight": "time_embed.0.weight", + "time_embedding.linear_1.bias": "time_embed.0.bias", + "time_embedding.linear_2.weight": "time_embed.2.weight", + "time_embedding.linear_2.bias": "time_embed.2.bias", + "conv_in.weight": "input_blocks.0.0.weight", + "conv_in.bias": "input_blocks.0.0.bias", + "class_embed_type": { + "timestep": { + "class_embedding.linear_1.weight": "label_emb.0.0.weight", + "class_embedding.linear_1.bias": "label_emb.0.0.bias", + "class_embedding.linear_2.weight": "label_emb.0.2.weight", + "class_embedding.linear_2.bias": "label_emb.0.2.bias", + }, + "text_time": { + "class_embedding.linear_1.weight": "label_emb.0.0.weight", + "class_embedding.linear_1.bias": "label_emb.0.0.bias", + "class_embedding.linear_2.weight": "label_emb.0.2.weight", + "class_embedding.linear_2.bias": "label_emb.0.2.bias", + }, + }, + }, + "vae": { + "encoder.conv_in.weight": "encoder.conv_in.weight", + "encoder.conv_in.bias": "encoder.conv_in.bias", + "encoder.conv_out.weight": "encoder.conv_out.weight", + "encoder.conv_out.bias": "encoder.conv_out.bias", + "encoder.conv_norm_out.weight": "encoder.conv_norm_out.weight", + "encoder.conv_norm_out.bias": "encoder.conv_norm_out.bias", + "decoder.conv_in.weight": "decoder.conv_in.weight", + "decoder.conv_in.bias": "decoder.conv_in.bias", + "decoder.conv_out.weight": "decoder.conv_out.weight", + "decoder.conv_out.bias": "decoder.conv_out.bias", + "decoder.conv_norm_out.weight": "decoder.conv_norm_out.weight", + "decoder.conv_norm_out.bias": "decoder.conv_norm_out.bias", + "quant_conv.weight": "quant_conv.weight", + "quant_conv.bias": "quant_conv.bias", + "post_quant_conv.weight": "post_quant_conv.weight", + "post_quant_conv.bias": "post_quant_conv.bias", + }, +} + + +UNET_TIME_EMBEDDING_LAYERS = [] + + textenc_conversion_lst = [ ("positional_embedding", "text_model.embeddings.position_embedding.weight"), ("token_embedding.weight", "text_model.embeddings.token_embedding.weight"), @@ -147,7 +190,7 @@ def 
fetch_original_config_file_from_file(config_files: list): def fetch_original_config(pipeline_class_name, checkpoint, original_config_file=None, config_files=None): if original_config_file: - original_config = OmegaConf.load(original_config_file) + original_config = yaml.safe_load(original_config_file) return original_config elif config_files: @@ -156,7 +199,7 @@ def fetch_original_config(pipeline_class_name, checkpoint, original_config_file= else: original_config_file = fetch_original_config_file_from_url(pipeline_class_name, checkpoint) - original_config = OmegaConf.load(original_config_file) + original_config = yaml.safe_load(original_config_file) return original_config @@ -187,18 +230,19 @@ def infer_model_type(pipeline_class_name, original_config, model_type=None, **kw return model_type has_cond_stage_config = ( - "cond_stage_config" in original_config.model.params - and original_config.model.params.cond_stage_config is not None + "cond_stage_config" in original_config["model"]["params"] + and original_config["model"]["params"]["cond_stage_config"] is not None ) has_network_config = ( - "network_config" in original_config.model.params and original_config.model.params.network_config is not None + "network_config" in original_config["model"]["params"] + and original_config["model"]["params"]["network_config"] is not None ) if has_cond_stage_config: - model_type = original_config["model"]["params"]["cond_stage_config"]["target"].split(".")[-1] + model_type = original_config["model"]["params"]["cond_stage_config"]["target"].split(".")[-1] elif has_network_config: - context_dim = original_config.model.params.network_config.params.context_dim + context_dim = original_config["model"]["params"]["network_config"]["params"]["context_dim"] if context_dim == 2048: model_type = "SDXL" else: @@ -221,7 +265,7 @@ def determine_image_size(pipeline_class_name, original_config, checkpoint, **kwa model_type = infer_model_type(pipeline_class_name, original_config, **kwargs) if pipeline_class_name == "StableDiffusionUpscalePipeline": - image_size = original_config.model.params.unet_config.params.image_size + image_size = original_config["model"]["params"]["unet_config"]["params"]["image_size"] return image_size elif model_type in ["SDXL", "SDXL-Refiner"]: @@ -413,57 +457,55 @@ def conv_attn_to_linear(checkpoint): checkpoint[key] = checkpoint[key][:, :, 0] -# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_unet_diffusers_config -def create_unet_diffusers_config(original_config, image_size: int, controlnet=False): +def create_unet_diffusers_config(original_config, image_size: int): """ Creates a config for the diffusers based on the config of the LDM model. 
""" - if controlnet: - unet_params = original_config.model.params.control_stage_config.params + if ( + "unet_config" in original_config["model"]["params"] + and original_config["model"]["params"]["unet_config"] is not None + ): + unet_params = original_config["model"]["params"]["unet_config"]["params"] else: - if "unet_config" in original_config.model.params and original_config.model.params.unet_config is not None: - unet_params = original_config.model.params.unet_config.params - else: - unet_params = original_config.model.params.network_config.params + unet_params = original_config["model"]["params"]["network_config"]["params"] - vae_params = original_config.model.params.first_stage_config.params.ddconfig - - block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] + vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] + block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]] down_block_types = [] resolution = 1 for i in range(len(block_out_channels)): - block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" + block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D" down_block_types.append(block_type) if i != len(block_out_channels) - 1: resolution *= 2 up_block_types = [] for i in range(len(block_out_channels)): - block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" + block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D" up_block_types.append(block_type) resolution //= 2 - if unet_params.transformer_depth is not None: + if unet_params["transformer_depth"] is not None: transformer_layers_per_block = ( - unet_params.transformer_depth - if isinstance(unet_params.transformer_depth, int) - else list(unet_params.transformer_depth) + unet_params["transformer_depth"] + if isinstance(unet_params["transformer_depth"], int) + else list(unet_params["transformer_depth"]) ) else: transformer_layers_per_block = 1 - vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1) + vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1) - head_dim = unet_params.num_heads if "num_heads" in unet_params else None + head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None use_linear_projection = ( - unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False + unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False ) if use_linear_projection: # stable diffusion 2-base-512 and 2-768 if head_dim is None: - head_dim_mult = unet_params.model_channels // unet_params.num_head_channels - head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)] + head_dim_mult = unet_params["model_channels"] // unet_params["num_head_channels"] + head_dim = [head_dim_mult * c for c in list(unet_params["channel_mult"])] class_embed_type = None addition_embed_type = None @@ -471,13 +513,15 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa projection_class_embeddings_input_dim = None context_dim = None - if unet_params.context_dim is not None: + if unet_params["context_dim"] is not None: context_dim = ( - unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0] + unet_params["context_dim"] + if isinstance(unet_params["context_dim"], int) + else 
unet_params["context_dim"][0] ) if "num_classes" in unet_params: - if unet_params.num_classes == "sequential": + if unet_params["num_classes"] == "sequential": if context_dim in [2048, 1280]: # SDXL addition_embed_type = "text_time" @@ -485,14 +529,14 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa else: class_embed_type = "projection" assert "adm_in_channels" in unet_params - projection_class_embeddings_input_dim = unet_params.adm_in_channels + projection_class_embeddings_input_dim = unet_params["adm_in_channels"] config = { "sample_size": image_size // vae_scale_factor, - "in_channels": unet_params.in_channels, + "in_channels": unet_params["in_channels"], "down_block_types": tuple(down_block_types), "block_out_channels": tuple(block_out_channels), - "layers_per_block": unet_params.num_res_blocks, + "layers_per_block": unet_params["num_res_blocks"], "cross_attention_dim": context_dim, "attention_head_dim": head_dim, "use_linear_projection": use_linear_projection, @@ -504,49 +548,42 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa } if "disable_self_attentions" in unet_params: - config["only_cross_attention"] = unet_params.disable_self_attentions + config["only_cross_attention"] = unet_params["disable_self_attentions"] - if "num_classes" in unet_params and isinstance(unet_params.num_classes, int): - config["num_class_embeds"] = unet_params.num_classes + if "num_classes" in unet_params and isinstance(unet_params["num_classes"], int): + config["num_class_embeds"] = unet_params["num_classes"] - if controlnet: - config["conditioning_channels"] = unet_params.hint_channels - else: - config["out_channels"] = unet_params.out_channels - config["up_block_types"] = tuple(up_block_types) + config["out_channels"] = unet_params["out_channels"] + config["up_block_types"] = tuple(up_block_types) return config -# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_vae_diffusers_config def create_vae_diffusers_config(original_config, image_size: int): """ Creates a config for the diffusers based on the config of the LDM model. 
""" - vae_params = original_config.model.params.first_stage_config.params.ddconfig - _ = original_config.model.params.first_stage_config.params.embed_dim + vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"] - block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult] + block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]] down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) config = { "sample_size": image_size, - "in_channels": vae_params.in_channels, - "out_channels": vae_params.out_ch, + "in_channels": vae_params["in_channels"], + "out_channels": vae_params["out_ch"], "down_block_types": tuple(down_block_types), "up_block_types": tuple(up_block_types), "block_out_channels": tuple(block_out_channels), - "latent_channels": vae_params.z_channels, - "layers_per_block": vae_params.num_res_blocks, + "latent_channels": vae_params["z_channels"], + "layers_per_block": vae_params["num_res_blocks"], } + return config -# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.convert_ldm_unet_checkpoint -def convert_ldm_unet_checkpoint( - checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False -): +def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, skip_extract_state_dict=False): """ Takes a state dict and a config, and returns a converted checkpoint. """ @@ -558,10 +595,7 @@ def convert_ldm_unet_checkpoint( unet_state_dict = {} keys = list(checkpoint.keys()) - if controlnet: - unet_key = "control_model." - else: - unet_key = "model.diffusion_model." + unet_key = "model.diffusion_model." # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: @@ -617,12 +651,10 @@ def convert_ldm_unet_checkpoint( new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] - if not controlnet: - new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] - new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] - new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] - new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] - + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] # Retrieves the keys for the input blocks only num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) input_blocks = { @@ -747,48 +779,6 @@ def convert_ldm_unet_checkpoint( new_checkpoint[new_path] = unet_state_dict[old_path] - if controlnet: - # conditioning embedding - - orig_index = 0 - - new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.weight" - ) - new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.bias" - ) - - orig_index += 2 - - diffusers_index = 0 - - while diffusers_index < 6: - new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.weight" - ) - 
new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.bias" - ) - diffusers_index += 1 - orig_index += 2 - - new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.weight" - ) - new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.bias" - ) - - # down blocks - for i in range(num_input_blocks): - new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight") - new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias") - - # mid block - new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight") - new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias") - return new_checkpoint @@ -824,13 +814,13 @@ def convert_ldm_vae_checkpoint(checkpoint, config): new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] # Retrieves the keys for the encoder down blocks only - num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + num_down_blocks = len(config["down_block_types"]) down_blocks = { layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) } # Retrieves the keys for the decoder up blocks only - num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + num_up_blocks = len(config["up_block_types"]) up_blocks = { layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) } @@ -1082,7 +1072,7 @@ def stable_unclip_image_encoder(original_config, local_files_only=False): encoders. 
""" - image_embedder_config = original_config.model.params.embedder_config + image_embedder_config = original_config["model"]["params"].embedder_config sd_clip_image_embedder_class = image_embedder_config.target sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1] @@ -1111,120 +1101,8 @@ def stable_unclip_image_encoder(original_config, local_files_only=False): return feature_extractor, image_encoder -def convert_paint_by_example_checkpoint(checkpoint, local_files_only=False): - config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only) - model = PaintByExampleImageEncoder(config) - - keys = list(checkpoint.keys()) - - text_model_dict = {} - - for key in keys: - if key.startswith("cond_stage_model.transformer"): - text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] - - # load clip vision - model.model.load_state_dict(text_model_dict) - - # load mapper - keys_mapper = { - k[len("cond_stage_model.mapper.res") :]: v - for k, v in checkpoint.items() - if k.startswith("cond_stage_model.mapper") - } - - MAPPING = { - "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"], - "attn.c_proj": ["attn1.to_out.0"], - "ln_1": ["norm1"], - "ln_2": ["norm3"], - "mlp.c_fc": ["ff.net.0.proj"], - "mlp.c_proj": ["ff.net.2"], - } - - mapped_weights = {} - for key, value in keys_mapper.items(): - prefix = key[: len("blocks.i")] - suffix = key.split(prefix)[-1].split(".")[-1] - name = key.split(prefix)[-1].split(suffix)[0][1:-1] - mapped_names = MAPPING[name] - - num_splits = len(mapped_names) - for i, mapped_name in enumerate(mapped_names): - new_name = ".".join([prefix, mapped_name, suffix]) - shape = value.shape[0] // num_splits - mapped_weights[new_name] = value[i * shape : (i + 1) * shape] - - model.mapper.load_state_dict(mapped_weights) - - # load final layer norm - model.final_layer_norm.load_state_dict( - { - "bias": checkpoint["cond_stage_model.final_ln.bias"], - "weight": checkpoint["cond_stage_model.final_ln.weight"], - } - ) - - # load final proj - model.proj_out.load_state_dict( - { - "bias": checkpoint["proj_out.bias"], - "weight": checkpoint["proj_out.weight"], - } - ) - - # load uncond vector - model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"]) - return model - - -def stable_unclip_image_noising_components( - original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None -): - """ - Returns the noising components for the img2img and txt2img unclip pipelines. - - Converts the stability noise augmentor into - 1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats - 2. a `DDPMScheduler` for holding the noise schedule - - If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided. 
- """ - noise_aug_config = original_config.model.params.noise_aug_config - noise_aug_class = noise_aug_config.target - noise_aug_class = noise_aug_class.split(".")[-1] - - if noise_aug_class == "CLIPEmbeddingNoiseAugmentation": - noise_aug_config = noise_aug_config.params - embedding_dim = noise_aug_config.timestep_dim - max_noise_level = noise_aug_config.noise_schedule_config.timesteps - beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule - - image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim) - image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule) - - if "clip_stats_path" in noise_aug_config: - if clip_stats_path is None: - raise ValueError("This stable unclip config requires a `clip_stats_path`") - - clip_mean, clip_std = torch.load(clip_stats_path, map_location=device) - clip_mean = clip_mean[None, :] - clip_std = clip_std[None, :] - - clip_stats_state_dict = { - "mean": clip_mean, - "std": clip_std, - } - - image_normalizer.load_state_dict(clip_stats_state_dict) - else: - raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}") - - return image_normalizer, image_noising_scheduler - - def create_ldm_bert_config(original_config): - bert_params = original_config.model.params.cond_stage_config.params + bert_params = original_config["model"]["params"]["cond_stage_config"]["params"] config = LDMBertConfig( d_model=bert_params.n_embed, encoder_layers=bert_params.n_layer, @@ -1416,7 +1294,7 @@ def create_scheduler(pipeline_class_name, original_config, checkpoint, checkpoin prediction_type = kwargs.get("prediction_type", None) global_step = checkpoint["global_step"] if "global_step" in checkpoint else None - num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000 + num_train_timesteps = original_config["model"]["params"].get("timesteps", None) or 1000 scheduler_config["num_train_timesteps"] = num_train_timesteps if ( @@ -1437,8 +1315,8 @@ def create_scheduler(pipeline_class_name, original_config, checkpoint, checkpoin scheduler_type = "euler" else: - beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02 - beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085 + beta_start = original_config["model"]["params"].get("linear_start", None) or 0.02 + beta_end = original_config["model"]["params"].get("linear_end", None) or 0.085 scheduler_config["beta_start"] = beta_start scheduler_config["beta_end"] = beta_end scheduler_config["beta_schedule"] = "scaled_linear" @@ -1484,64 +1362,3 @@ def create_scheduler(pipeline_class_name, original_config, checkpoint, checkpoin } return {"scheduler": scheduler} - - -def create_stable_unclip_components( - pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs -): - local_files_only = kwargs.get("local_files_only", False) - clip_stats_path = kwargs.get("clip_stats_path", None) - - image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components( - original_config, - clip_stats_path=clip_stats_path, - ) - - if pipeline_class_name == "StableUnCLIPPipeline": - stable_unclip_prior = kwargs.get("stable_unclip_prior", None) - if stable_unclip_prior is None and stable_unclip_prior != "karlo": - raise NotImplementedError(f"Unknown prior for Stable UnCLIP model: {stable_unclip_prior}") - - try: - config_name = "kakaobrain/karlo-v1-alpha" - prior = PriorTransformer.from_pretrained(config_name, subfolder="prior", 
local_files_only=local_files_only) - except Exception: - raise ValueError( - f"With local_files_only set to {local_files_only}, you must first locally save the prior in the following path: '{config_name}'." - ) - - try: - config_name = "openai/clip-vit-large-patch14" - prior_tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only) - prior_text_encoder = CLIPTextModelWithProjection.from_pretrained( - config_name, local_files_only=local_files_only - ) - prior_scheduler = DDPMScheduler.from_pretrained( - config_name, subfolder="prior_scheduler", local_files_only=local_files_only - ) - - except Exception: - raise ValueError( - f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: '{config_name}'." - ) - else: - return { - "prior": prior, - "prior_tokenizer": prior_tokenizer, - "prior_text_encoder": prior_text_encoder, - "prior_scheduler": prior_scheduler, - "image_normalizer": image_normalizer, - "image_noise_scheduler": image_noising_scheduler, - } - - else: - feature_extractor, image_encoder = stable_unclip_image_encoder(original_config) - - return { - "feature_extractor": feature_extractor, - "image_encoder": image_encoder, - "image_normalizer": image_normalizer, - "image_noising_scheduler": image_noising_scheduler, - } - - return
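Note: with OmegaConf removed, the original checkpoint config is handled as the plain nested dictionaries returned by `yaml.safe_load`, so lookups use key indexing (and `.get()` for keys that may be absent) rather than attribute access. A minimal sketch of the assumed access pattern, using a hypothetical config path:

    import yaml

    # Parse an original LDM-style config into plain nested dicts (no OmegaConf).
    with open("v1-inference.yaml") as f:  # hypothetical path
        original_config = yaml.safe_load(f)

    # Attribute access such as `original_config.model.params` no longer applies to a dict;
    # keys are indexed directly, with .get() used for keys that may be missing.
    unet_params = original_config["model"]["params"]["unet_config"]["params"]
    block_out_channels = [unet_params["model_channels"] * m for m in unet_params["channel_mult"]]
    num_train_timesteps = original_config["model"]["params"].get("timesteps", None) or 1000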