mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
mixed inference
This commit is contained in:
@@ -313,7 +313,7 @@ def create_vae_diffusers_config(vae_params):
|
||||
return config
|
||||
|
||||
|
||||
def create_diffusers_schedular(original_config):
|
||||
def create_diffusers_scheduler(original_config):
|
||||
schedular = DDIMScheduler(
|
||||
num_train_timesteps=original_config.model.params.timesteps,
|
||||
beta_start=original_config.model.params.linear_start,
|
||||
@@ -323,16 +323,6 @@ def create_diffusers_schedular(original_config):
|
||||
return schedular
|
||||
|
||||
|
||||
def create_ldm_bert_config(original_config):
|
||||
bert_params = original_config.model.parms.cond_stage_config.params
|
||||
config = LDMBertConfig(
|
||||
d_model=bert_params.n_embed,
|
||||
encoder_layers=bert_params.n_layer,
|
||||
encoder_ffn_dim=bert_params.n_embed * 4,
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
def convert_vd_unet_checkpoint(checkpoint, config, unet_key, extract_ema=False):
|
||||
"""
|
||||
Takes a state dict and a config, and returns a converted checkpoint.
|
||||
@@ -503,7 +493,7 @@ def convert_vd_unet_checkpoint(checkpoint, config, unet_key, extract_ema=False):
|
||||
return new_checkpoint
|
||||
|
||||
|
||||
def convert_ldm_vae_checkpoint(checkpoint, config):
|
||||
def convert_vd_vae_checkpoint(checkpoint, config):
|
||||
# extract state dict for VAE
|
||||
vae_state_dict = {}
|
||||
keys = list(checkpoint.keys())
|
||||
@@ -608,72 +598,6 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
|
||||
return new_checkpoint
|
||||
|
||||
|
||||
def convert_ldm_bert_checkpoint(checkpoint, config):
|
||||
def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
|
||||
hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
|
||||
hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
|
||||
hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight
|
||||
|
||||
hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
|
||||
hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias
|
||||
|
||||
def _copy_linear(hf_linear, pt_linear):
|
||||
hf_linear.weight = pt_linear.weight
|
||||
hf_linear.bias = pt_linear.bias
|
||||
|
||||
def _copy_layer(hf_layer, pt_layer):
|
||||
# copy layer norms
|
||||
_copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
|
||||
_copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])
|
||||
|
||||
# copy attn
|
||||
_copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])
|
||||
|
||||
# copy MLP
|
||||
pt_mlp = pt_layer[1][1]
|
||||
_copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
|
||||
_copy_linear(hf_layer.fc2, pt_mlp.net[2])
|
||||
|
||||
def _copy_layers(hf_layers, pt_layers):
|
||||
for i, hf_layer in enumerate(hf_layers):
|
||||
if i != 0:
|
||||
i += i
|
||||
pt_layer = pt_layers[i : i + 2]
|
||||
_copy_layer(hf_layer, pt_layer)
|
||||
|
||||
hf_model = LDMBertModel(config).eval()
|
||||
|
||||
# copy embeds
|
||||
hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
|
||||
hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight
|
||||
|
||||
# copy layer norm
|
||||
_copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
|
||||
|
||||
# copy hidden layers
|
||||
_copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)
|
||||
|
||||
_copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)
|
||||
|
||||
return hf_model
|
||||
|
||||
|
||||
def convert_ldm_clip_checkpoint(checkpoint):
|
||||
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
|
||||
|
||||
keys = list(checkpoint.keys())
|
||||
|
||||
text_model_dict = {}
|
||||
|
||||
for key in keys:
|
||||
if key.startswith("cond_stage_model.transformer"):
|
||||
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
|
||||
|
||||
text_model.load_state_dict(text_model_dict)
|
||||
|
||||
return text_model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
@@ -758,18 +682,18 @@ if __name__ == "__main__":
|
||||
image_unet = UNet2DConditionModel(**image_unet_config)
|
||||
image_unet.load_state_dict(converted_image_unet_checkpoint)
|
||||
|
||||
# text_unet_config = create_unet_diffusers_config(TEXT_UNET_CONFIG)
|
||||
# converted_text_unet_checkpoint = convert_vd_unet_checkpoint(
|
||||
# checkpoint, text_unet_config, unet_key="model.diffusion_model.unet_text.", extract_ema=args.extract_ema
|
||||
# )
|
||||
# text_unet = UNet2DConditionModel(**text_unet_config)
|
||||
# text_unet.load_state_dict(converted_text_unet_checkpoint)
|
||||
text_unet_config = create_unet_diffusers_config(TEXT_UNET_CONFIG)
|
||||
converted_text_unet_checkpoint = convert_vd_unet_checkpoint(
|
||||
checkpoint, text_unet_config, unet_key="model.diffusion_model.unet_text.", extract_ema=args.extract_ema
|
||||
)
|
||||
text_unet = UNet2DConditionModel(**text_unet_config)
|
||||
text_unet.load_state_dict(converted_text_unet_checkpoint, strict=False)
|
||||
|
||||
# Convert the VAE model.
|
||||
if args.vae_checkpoint_path is not None:
|
||||
vae_config = create_vae_diffusers_config(AUTOENCODER_CONFIG)
|
||||
checkpoint = torch.load(args.vae_checkpoint_path)
|
||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
||||
converted_vae_checkpoint = convert_vd_vae_checkpoint(checkpoint, vae_config)
|
||||
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
vae.load_state_dict(converted_vae_checkpoint)
|
||||
@@ -786,7 +710,7 @@ if __name__ == "__main__":
|
||||
text_encoder=text_encoder,
|
||||
image_encoder=image_encoder,
|
||||
image_unet=image_unet,
|
||||
# text_unet=text_unet,
|
||||
text_unet=text_unet,
|
||||
vae=vae,
|
||||
)
|
||||
pipe.save_pretrained(args.dump_path)
|
||||
|
||||
Reference in New Issue
Block a user