mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
update
@@ -17,11 +17,10 @@
import re
from contextlib import nullcontext
from io import BytesIO
from typing import Optional

import requests
import torch
from omegaconf import OmegaConf
import yaml
from safetensors.torch import load_file as safe_load
from transformers import (
BertTokenizerFast,
@@ -30,14 +29,11 @@ from transformers import (
CLIPTextModel,
CLIPTextModelWithProjection,
CLIPTokenizer,
CLIPVisionConfig,
CLIPVisionModelWithProjection,
)

from ..models import AutoencoderKL, PriorTransformer, UNet2DConditionModel
from ..models import AutoencoderKL, UNet2DConditionModel
from ..pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
from ..pipelines.paint_by_example import PaintByExampleImageEncoder
from ..pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
from ..schedulers import (
DDIMScheduler,
DDPMScheduler,
@@ -85,6 +81,53 @@ SCHEDULER_DEFAULT_CONFIG = {
"timestep_spacing": "leading",
}

DIFFUSERS_TO_LDM_MAPPING = {
"unet": {
"time_embedding.linear_1.weight": "time_embed.0.weight",
"time_embedding.linear_1.bias": "time_embed.0.bias",
"time_embedding.linear_2.weight": "time_embed.2.weight",
"time_embedding.linear_2.bias": "time_embed.2.bias",
"conv_in.weight": "input_blocks.0.0.weight",
"conv_in.bias": "input_blocks.0.0.bias",
"class_embed_type": {
"timestep": {
"class_embedding.linear_1.weight": "label_emb.0.0.weight",
"class_embedding.linear_1.bias": "label_emb.0.0.bias",
"class_embedding.linear_2.weight": "label_emb.0.2.weight",
"class_embedding.linear_2.bias": "label_emb.0.2.bias",
},
"text_time": {
"class_embedding.linear_1.weight": "label_emb.0.0.weight",
"class_embedding.linear_1.bias": "label_emb.0.0.bias",
"class_embedding.linear_2.weight": "label_emb.0.2.weight",
"class_embedding.linear_2.bias": "label_emb.0.2.bias",
},
},
},
"vae": {
"encoder.conv_in.weight": "encoder.conv_in.weight",
"encoder.conv_in.bias": "encoder.conv_in.bias",
"encoder.conv_out.weight": "encoder.conv_out.weight",
"encoder.conv_out.bias": "encoder.conv_out.bias",
"encoder.conv_norm_out.weight": "encoder.conv_norm_out.weight",
"encoder.conv_norm_out.bias": "encoder.conv_norm_out.bias",
"decoder.conv_in.weight": "decoder.conv_in.weight",
"decoder.conv_in.bias": "decoder.conv_in.bias",
"decoder.conv_out.weight": "decoder.conv_out.weight",
"decoder.conv_out.bias": "decoder.conv_out.bias",
"decoder.conv_norm_out.weight": "decoder.conv_norm_out.weight",
"decoder.conv_norm_out.bias": "decoder.conv_norm_out.bias",
"quant_conv.weight": "quant_conv.weight",
"quant_conv.bias": "quant_conv.bias",
"post_quant_conv.weight": "post_quant_conv.weight",
"post_quant_conv.bias": "post_quant_conv.bias",
},
}
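
The "unet" and "vae" entries in DIFFUSERS_TO_LDM_MAPPING are flat diffusers-name to LDM-name lookup tables; the nested "class_embed_type" block is keyed by embedding type first. As a minimal sketch (not part of this commit), a hypothetical rename_keys helper shows how such a flat table could be applied to a state dict; the example tensor shape is made up:

import torch

# Hypothetical helper, for illustration only: applies a flat
# {diffusers_name: ldm_name} table such as DIFFUSERS_TO_LDM_MAPPING["vae"]
# to a state dict; keys without an entry pass through unchanged.
def rename_keys(state_dict, mapping):
    return {mapping.get(key, key): value for key, value in state_dict.items()}

diffusers_sd = {"quant_conv.weight": torch.zeros(8, 4, 1, 1)}  # made-up tensor
vae_table = {"quant_conv.weight": "quant_conv.weight"}  # one entry of the "vae" table
ldm_sd = rename_keys(diffusers_sd, vae_table)
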
UNET_TIME_EMBEDDING_LAYERS = []


textenc_conversion_lst = [
("positional_embedding", "text_model.embeddings.position_embedding.weight"),
("token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
@@ -147,7 +190,7 @@ def fetch_original_config_file_from_file(config_files: list):

def fetch_original_config(pipeline_class_name, checkpoint, original_config_file=None, config_files=None):
if original_config_file:
original_config = OmegaConf.load(original_config_file)
original_config = yaml.safe_load(original_config_file)
return original_config

elif config_files:
@@ -156,7 +199,7 @@ def fetch_original_config(pipeline_class_name, checkpoint, original_config_file=
else:
original_config_file = fetch_original_config_file_from_url(pipeline_class_name, checkpoint)

original_config = OmegaConf.load(original_config_file)
original_config = yaml.safe_load(original_config_file)

return original_config
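
The hunks below swap OmegaConf attribute access for plain dictionary access on the loaded config. A minimal sketch of the resulting pattern, assuming a hypothetical local v1-inference.yaml in the usual LDM layout:

import yaml

# Sketch only: "v1-inference.yaml" is an assumed local path to an original
# LDM/Stable Diffusion config; any YAML with a model.params section works.
with open("v1-inference.yaml") as f:
    original_config = yaml.safe_load(f)  # plain nested dicts, no OmegaConf object

# Dictionary-style access, matching the converted helpers below.
unet_params = original_config["model"]["params"]["unet_config"]["params"]
num_train_timesteps = original_config["model"]["params"].get("timesteps", 1000)
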
@@ -187,18 +230,19 @@ def infer_model_type(pipeline_class_name, original_config, model_type=None, **kw
return model_type

has_cond_stage_config = (
"cond_stage_config" in original_config.model.params
and original_config.model.params.cond_stage_config is not None
"cond_stage_config" in original_config["model"]["params"]
and original_config["model"]["params"]["cond_stage_config"] is not None
)
has_network_config = (
"network_config" in original_config.model.params and original_config.model.params.network_config is not None
"network_config" in original_config["model"]["params"]
and original_config["model"]["params"]["network_config"] is not None
)

if has_cond_stage_config:
model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
model_type = original_config["model"]["params"]["cond_stage_config"]["target"].split(".")[-1]

elif has_network_config:
context_dim = original_config.model.params.network_config.params.context_dim
context_dim = original_config["model"]["params"]["network_config"]["params"]["context_dim"]
if context_dim == 2048:
model_type = "SDXL"
else:
@@ -221,7 +265,7 @@ def determine_image_size(pipeline_class_name, original_config, checkpoint, **kwa
model_type = infer_model_type(pipeline_class_name, original_config, **kwargs)

if pipeline_class_name == "StableDiffusionUpscalePipeline":
image_size = original_config.model.params.unet_config.params.image_size
image_size = original_config["model"]["params"].unet_config.params.image_size
return image_size

elif model_type in ["SDXL", "SDXL-Refiner"]:
@@ -413,57 +457,55 @@ def conv_attn_to_linear(checkpoint):
checkpoint[key] = checkpoint[key][:, :, 0]


# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_unet_diffusers_config
def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
def create_unet_diffusers_config(original_config, image_size: int):
"""
Creates a config for the diffusers based on the config of the LDM model.
"""
if controlnet:
unet_params = original_config.model.params.control_stage_config.params
if (
"unet_config" in original_config["model"]["params"]
and original_config["model"]["params"]["unet_config"] is not None
):
unet_params = original_config["model"]["params"]["unet_config"]["params"]
else:
if "unet_config" in original_config.model.params and original_config.model.params.unet_config is not None:
unet_params = original_config.model.params.unet_config.params
else:
unet_params = original_config.model.params.network_config.params
unet_params = original_config["model"]["params"]["network_config"]["params"]

vae_params = original_config.model.params.first_stage_config.params.ddconfig

block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
block_out_channels = [unet_params["model_channels"] * mult for mult in unet_params["channel_mult"]]

down_block_types = []
resolution = 1
for i in range(len(block_out_channels)):
block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
block_type = "CrossAttnDownBlock2D" if resolution in unet_params["attention_resolutions"] else "DownBlock2D"
down_block_types.append(block_type)
if i != len(block_out_channels) - 1:
resolution *= 2

up_block_types = []
for i in range(len(block_out_channels)):
block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
block_type = "CrossAttnUpBlock2D" if resolution in unet_params["attention_resolutions"] else "UpBlock2D"
up_block_types.append(block_type)
resolution //= 2

if unet_params.transformer_depth is not None:
if unet_params["transformer_depth"] is not None:
transformer_layers_per_block = (
unet_params.transformer_depth
if isinstance(unet_params.transformer_depth, int)
else list(unet_params.transformer_depth)
unet_params["transformer_depth"]
if isinstance(unet_params["transformer_depth"], int)
else list(unet_params["transformer_depth"])
)
else:
transformer_layers_per_block = 1

vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
vae_scale_factor = 2 ** (len(vae_params["ch_mult"]) - 1)

head_dim = unet_params.num_heads if "num_heads" in unet_params else None
head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None
use_linear_projection = (
unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
unet_params["use_linear_in_transformer"] if "use_linear_in_transformer" in unet_params else False
)
if use_linear_projection:
# stable diffusion 2-base-512 and 2-768
if head_dim is None:
head_dim_mult = unet_params.model_channels // unet_params.num_head_channels
head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)]
head_dim_mult = unet_params["model_channels"] // unet_params["num_head_channels"]
head_dim = [head_dim_mult * c for c in list(unet_params["channel_mult"])]

class_embed_type = None
addition_embed_type = None
@@ -471,13 +513,15 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
projection_class_embeddings_input_dim = None
context_dim = None

if unet_params.context_dim is not None:
if unet_params["context_dim"] is not None:
context_dim = (
unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0]
unet_params["context_dim"]
if isinstance(unet_params["context_dim"], int)
else unet_params["context_dim"][0]
)

if "num_classes" in unet_params:
if unet_params.num_classes == "sequential":
if unet_params["num_classes"] == "sequential":
if context_dim in [2048, 1280]:
# SDXL
addition_embed_type = "text_time"
@@ -485,14 +529,14 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
else:
class_embed_type = "projection"
assert "adm_in_channels" in unet_params
projection_class_embeddings_input_dim = unet_params.adm_in_channels
projection_class_embeddings_input_dim = unet_params["adm_in_channels"]

config = {
"sample_size": image_size // vae_scale_factor,
"in_channels": unet_params.in_channels,
"in_channels": unet_params["in_channels"],
"down_block_types": tuple(down_block_types),
"block_out_channels": tuple(block_out_channels),
"layers_per_block": unet_params.num_res_blocks,
"layers_per_block": unet_params["num_res_blocks"],
"cross_attention_dim": context_dim,
"attention_head_dim": head_dim,
"use_linear_projection": use_linear_projection,
@@ -504,49 +548,42 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
}

if "disable_self_attentions" in unet_params:
config["only_cross_attention"] = unet_params.disable_self_attentions
config["only_cross_attention"] = unet_params["disable_self_attentions"]

if "num_classes" in unet_params and isinstance(unet_params.num_classes, int):
config["num_class_embeds"] = unet_params.num_classes
if "num_classes" in unet_params and isinstance(unet_params["num_classes"], int):
config["num_class_embeds"] = unet_params["num_classes"]

if controlnet:
config["conditioning_channels"] = unet_params.hint_channels
else:
config["out_channels"] = unet_params.out_channels
config["up_block_types"] = tuple(up_block_types)
config["out_channels"] = unet_params["out_channels"]
config["up_block_types"] = tuple(up_block_types)

return config


# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.create_vae_diffusers_config
def create_vae_diffusers_config(original_config, image_size: int):
"""
Creates a config for the diffusers based on the config of the LDM model.
"""
vae_params = original_config.model.params.first_stage_config.params.ddconfig
_ = original_config.model.params.first_stage_config.params.embed_dim
vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]

block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)

config = {
"sample_size": image_size,
"in_channels": vae_params.in_channels,
"out_channels": vae_params.out_ch,
"in_channels": vae_params["in_channels"],
"out_channels": vae_params["out_ch"],
"down_block_types": tuple(down_block_types),
"up_block_types": tuple(up_block_types),
"block_out_channels": tuple(block_out_channels),
"latent_channels": vae_params.z_channels,
"layers_per_block": vae_params.num_res_blocks,
"latent_channels": vae_params["z_channels"],
"layers_per_block": vae_params["num_res_blocks"],
}

return config


# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.convert_ldm_unet_checkpoint
def convert_ldm_unet_checkpoint(
checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False
):
def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, skip_extract_state_dict=False):
"""
Takes a state dict and a config, and returns a converted checkpoint.
"""
@@ -558,10 +595,7 @@ def convert_ldm_unet_checkpoint(
unet_state_dict = {}
keys = list(checkpoint.keys())

if controlnet:
unet_key = "control_model."
else:
unet_key = "model.diffusion_model."
unet_key = "model.diffusion_model."

# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
@@ -617,12 +651,10 @@ def convert_ldm_unet_checkpoint(
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

if not controlnet:
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
# Retrieves the keys for the input blocks only
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
input_blocks = {
@@ -747,48 +779,6 @@ def convert_ldm_unet_checkpoint(

new_checkpoint[new_path] = unet_state_dict[old_path]

if controlnet:
# conditioning embedding

orig_index = 0

new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop(
f"input_hint_block.{orig_index}.weight"
)
new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop(
f"input_hint_block.{orig_index}.bias"
)

orig_index += 2

diffusers_index = 0

while diffusers_index < 6:
new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
f"input_hint_block.{orig_index}.weight"
)
new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop(
f"input_hint_block.{orig_index}.bias"
)
diffusers_index += 1
orig_index += 2

new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
f"input_hint_block.{orig_index}.weight"
)
new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
f"input_hint_block.{orig_index}.bias"
)

# down blocks
for i in range(num_input_blocks):
new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight")
new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias")

# mid block
new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")

return new_checkpoint
@@ -824,13 +814,13 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]

# Retrieves the keys for the encoder down blocks only
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
num_down_blocks = len(config["down_block_types"])
down_blocks = {
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
}

# Retrieves the keys for the decoder up blocks only
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
num_up_blocks = len(config["up_block_types"])
up_blocks = {
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
}
@@ -1082,7 +1072,7 @@ def stable_unclip_image_encoder(original_config, local_files_only=False):
encoders.
"""

image_embedder_config = original_config.model.params.embedder_config
image_embedder_config = original_config["model"]["params"].embedder_config

sd_clip_image_embedder_class = image_embedder_config.target
sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1]
@@ -1111,120 +1101,8 @@ def stable_unclip_image_encoder(original_config, local_files_only=False):
return feature_extractor, image_encoder


def convert_paint_by_example_checkpoint(checkpoint, local_files_only=False):
config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only)
model = PaintByExampleImageEncoder(config)

keys = list(checkpoint.keys())

text_model_dict = {}

for key in keys:
if key.startswith("cond_stage_model.transformer"):
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]

# load clip vision
model.model.load_state_dict(text_model_dict)

# load mapper
keys_mapper = {
k[len("cond_stage_model.mapper.res") :]: v
for k, v in checkpoint.items()
if k.startswith("cond_stage_model.mapper")
}

MAPPING = {
"attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
"attn.c_proj": ["attn1.to_out.0"],
"ln_1": ["norm1"],
"ln_2": ["norm3"],
"mlp.c_fc": ["ff.net.0.proj"],
"mlp.c_proj": ["ff.net.2"],
}

mapped_weights = {}
for key, value in keys_mapper.items():
prefix = key[: len("blocks.i")]
suffix = key.split(prefix)[-1].split(".")[-1]
name = key.split(prefix)[-1].split(suffix)[0][1:-1]
mapped_names = MAPPING[name]

num_splits = len(mapped_names)
for i, mapped_name in enumerate(mapped_names):
new_name = ".".join([prefix, mapped_name, suffix])
shape = value.shape[0] // num_splits
mapped_weights[new_name] = value[i * shape : (i + 1) * shape]

model.mapper.load_state_dict(mapped_weights)

# load final layer norm
model.final_layer_norm.load_state_dict(
{
"bias": checkpoint["cond_stage_model.final_ln.bias"],
"weight": checkpoint["cond_stage_model.final_ln.weight"],
}
)

# load final proj
model.proj_out.load_state_dict(
{
"bias": checkpoint["proj_out.bias"],
"weight": checkpoint["proj_out.weight"],
}
)

# load uncond vector
model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"])
return model

def stable_unclip_image_noising_components(
original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None
):
"""
Returns the noising components for the img2img and txt2img unclip pipelines.

Converts the stability noise augmentor into
1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats
2. a `DDPMScheduler` for holding the noise schedule

If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided.
"""
noise_aug_config = original_config.model.params.noise_aug_config
noise_aug_class = noise_aug_config.target
noise_aug_class = noise_aug_class.split(".")[-1]

if noise_aug_class == "CLIPEmbeddingNoiseAugmentation":
noise_aug_config = noise_aug_config.params
embedding_dim = noise_aug_config.timestep_dim
max_noise_level = noise_aug_config.noise_schedule_config.timesteps
beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule

image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim)
image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule)

if "clip_stats_path" in noise_aug_config:
if clip_stats_path is None:
raise ValueError("This stable unclip config requires a `clip_stats_path`")

clip_mean, clip_std = torch.load(clip_stats_path, map_location=device)
clip_mean = clip_mean[None, :]
clip_std = clip_std[None, :]

clip_stats_state_dict = {
"mean": clip_mean,
"std": clip_std,
}

image_normalizer.load_state_dict(clip_stats_state_dict)
else:
raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}")

return image_normalizer, image_noising_scheduler


def create_ldm_bert_config(original_config):
bert_params = original_config.model.params.cond_stage_config.params
bert_params = original_config["model"]["params"].cond_stage_config.params
config = LDMBertConfig(
d_model=bert_params.n_embed,
encoder_layers=bert_params.n_layer,
@@ -1416,7 +1294,7 @@ def create_scheduler(pipeline_class_name, original_config, checkpoint, checkpoin
prediction_type = kwargs.get("prediction_type", None)
global_step = checkpoint["global_step"] if "global_step" in checkpoint else None

num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000
num_train_timesteps = getattr(original_config["model"]["params"], "timesteps", None) or 1000
scheduler_config["num_train_timesteps"] = num_train_timesteps

if (
@@ -1437,8 +1315,8 @@ def create_scheduler(pipeline_class_name, original_config, checkpoint, checkpoin
scheduler_type = "euler"

else:
beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02
beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085
beta_start = getattr(original_config["model"]["params"], "linear_start", None) or 0.02
beta_end = getattr(original_config["model"]["params"], "linear_end", None) or 0.085
scheduler_config["beta_start"] = beta_start
scheduler_config["beta_end"] = beta_end
scheduler_config["beta_schedule"] = "scaled_linear"
@@ -1484,64 +1362,3 @@ def create_scheduler(pipeline_class_name, original_config, checkpoint, checkpoin
}

return {"scheduler": scheduler}


def create_stable_unclip_components(
pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs
):
local_files_only = kwargs.get("local_files_only", False)
clip_stats_path = kwargs.get("clip_stats_path", None)

image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components(
original_config,
clip_stats_path=clip_stats_path,
)

if pipeline_class_name == "StableUnCLIPPipeline":
stable_unclip_prior = kwargs.get("stable_unclip_prior", None)
if stable_unclip_prior is None and stable_unclip_prior != "karlo":
raise NotImplementedError(f"Unknown prior for Stable UnCLIP model: {stable_unclip_prior}")

try:
config_name = "kakaobrain/karlo-v1-alpha"
prior = PriorTransformer.from_pretrained(config_name, subfolder="prior", local_files_only=local_files_only)
except Exception:
raise ValueError(
f"With local_files_only set to {local_files_only}, you must first locally save the prior in the following path: '{config_name}'."
)

try:
config_name = "openai/clip-vit-large-patch14"
prior_tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only)
prior_text_encoder = CLIPTextModelWithProjection.from_pretrained(
config_name, local_files_only=local_files_only
)
prior_scheduler = DDPMScheduler.from_pretrained(
config_name, subfolder="prior_scheduler", local_files_only=local_files_only
)

except Exception:
raise ValueError(
f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: '{config_name}'."
)
else:
return {
"prior": prior,
"prior_tokenizer": prior_tokenizer,
"prior_text_encoder": prior_text_encoder,
"prior_scheduler": prior_scheduler,
"image_normalizer": image_normalizer,
"image_noise_scheduler": image_noising_scheduler,
}

else:
feature_extractor, image_encoder = stable_unclip_image_encoder(original_config)

return {
"feature_extractor": feature_extractor,
"image_encoder": image_encoder,
"image_normalizer": image_normalizer,
"image_noising_scheduler": image_noising_scheduler,
}

return