mirror of https://github.com/huggingface/diffusers.git
CogView4 Control Block (#10809)
* cogview4 control training

Co-authored-by: OleehyO <leehy0357@gmail.com>
Co-authored-by: yiyixuxu <yixu310@gmail.com>
@@ -53,8 +53,18 @@ args = parser.parse_args()
 # this is specific to `AdaLayerNormContinuous`:
 # the diffusers implementation splits the linear projection into scale, shift while CogView4 splits it into shift, scale
 def swap_scale_shift(weight, dim):
-    shift, scale = weight.chunk(2, dim=0)
-    new_weight = torch.cat([scale, shift], dim=0)
+    """
+    Swap the scale and shift components in the weight tensor.
+
+    Args:
+        weight (torch.Tensor): The original weight tensor.
+        dim (int): The dimension along which to split.
+
+    Returns:
+        torch.Tensor: The modified weight tensor with scale and shift swapped.
+    """
+    shift, scale = weight.chunk(2, dim=dim)
+    new_weight = torch.cat([scale, shift], dim=dim)
     return new_weight
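As a toy check of what `swap_scale_shift` does: CogView4 stores its AdaLN projection output as (shift, scale) halves while diffusers' `AdaLayerNormContinuous` expects (scale, shift), so the converter reorders the halves of the projection weight. A minimal sketch (tensor values are made up for illustration):

import torch

def swap_scale_shift(weight, dim):
    shift, scale = weight.chunk(2, dim=dim)
    return torch.cat([scale, shift], dim=dim)

# A 4-row weight standing in for a (shift, scale) AdaLN projection:
w = torch.tensor([[1.0], [2.0], [3.0], [4.0]])  # rows 0-1 = shift, rows 2-3 = scale
print(swap_scale_shift(w, dim=0))  # scale now comes first: [[3.], [4.], [1.], [2.]]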
@@ -200,6 +210,7 @@ def main(args):
         "norm_num_groups": 32,
         "sample_size": 1024,
         "scaling_factor": 1.0,
+        "shift_factor": 0.0,
         "force_upcast": True,
         "use_quant_conv": False,
         "use_post_quant_conv": False,
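For context on these two constants: other diffusers pipelines normalize VAE latents with them before the transformer and invert the mapping before decoding. A sketch of that convention (not code from this script); with 1.0 and 0.0, as configured here, both directions are identities:

import torch

scaling_factor, shift_factor = 1.0, 0.0  # the values configured above

def normalize_latents(latents: torch.Tensor) -> torch.Tensor:
    # encoding direction: raw VAE latents -> latents fed to the transformer
    return (latents - shift_factor) * scaling_factor

def denormalize_latents(latents: torch.Tensor) -> torch.Tensor:
    # decoding direction: transformer output -> latents for the VAE decoder
    return latents / scaling_factor + shift_factor

x = torch.randn(1, 16, 32, 32)
assert torch.equal(denormalize_latents(normalize_latents(x)), x)  # identity here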
@@ -25,9 +25,15 @@ import argparse

 import torch
 from tqdm import tqdm
-from transformers import GlmForCausalLM, PreTrainedTokenizerFast
+from transformers import GlmModel, PreTrainedTokenizerFast

-from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
+from diffusers import (
+    AutoencoderKL,
+    CogView4ControlPipeline,
+    CogView4Pipeline,
+    CogView4Transformer2DModel,
+    FlowMatchEulerDiscreteScheduler,
+)
 from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
@@ -112,6 +118,12 @@ parser.add_argument(
     default=128,
     help="Maximum size for positional embeddings.",
 )
+parser.add_argument(
+    "--control",
+    action="store_true",
+    default=False,
+    help="Whether to use the control model.",
+)

 args = parser.parse_args()
@@ -150,13 +162,15 @@ def convert_megatron_transformer_checkpoint_to_diffusers(
     Returns:
         dict: The converted state dictionary compatible with Diffusers.
     """
-    ckpt = torch.load(ckpt_path, map_location="cpu")
+    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
     mega = ckpt["model"]

     new_state_dict = {}

     # Patch Embedding
-    new_state_dict["patch_embed.proj.weight"] = mega["encoder_expand_linear.weight"].reshape(hidden_size, 64)
+    new_state_dict["patch_embed.proj.weight"] = mega["encoder_expand_linear.weight"].reshape(
+        hidden_size, 128 if args.control else 64
+    )
     new_state_dict["patch_embed.proj.bias"] = mega["encoder_expand_linear.bias"]
     new_state_dict["patch_embed.text_proj.weight"] = mega["text_projector.weight"]
     new_state_dict["patch_embed.text_proj.bias"] = mega["text_projector.bias"]
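The 64 vs 128 follows from the patch embedding's input width: it flattens a patch_size x patch_size window of latent channels into one token, and the control variant doubles the latent channels from 16 to 32 (see `in_channels=32 if args.control else 16` later in this diff). A worked check of the arithmetic:

# Input width of the patch projection = in_channels * patch_size**2
patch_size = 2
in_channels = 16          # plain CogView4 latents
print(in_channels * patch_size**2)           # 64

in_channels_control = 32  # control variant: twice the latent channels
print(in_channels_control * patch_size**2)   # 128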
@@ -189,14 +203,8 @@ def convert_megatron_transformer_checkpoint_to_diffusers(
         block_prefix = f"transformer_blocks.{i}."

         # AdaLayerNorm
-        new_state_dict[block_prefix + "norm1.linear.weight"] = swap_scale_shift(
-            mega[f"decoder.layers.{i}.adaln.weight"], dim=0
-        )
-        new_state_dict[block_prefix + "norm1.linear.bias"] = swap_scale_shift(
-            mega[f"decoder.layers.{i}.adaln.bias"], dim=0
-        )
+        new_state_dict[block_prefix + "norm1.linear.weight"] = mega[f"decoder.layers.{i}.adaln.weight"]
+        new_state_dict[block_prefix + "norm1.linear.bias"] = mega[f"decoder.layers.{i}.adaln.bias"]

         # QKV
         qkv_weight = mega[f"decoder.layers.{i}.self_attention.linear_qkv.weight"]
         qkv_bias = mega[f"decoder.layers.{i}.self_attention.linear_qkv.bias"]
@@ -221,7 +229,7 @@ def convert_megatron_transformer_checkpoint_to_diffusers(
         # Attention Output
         new_state_dict[block_prefix + "attn1.to_out.0.weight"] = mega[
             f"decoder.layers.{i}.self_attention.linear_proj.weight"
-        ].T
+        ]
         new_state_dict[block_prefix + "attn1.to_out.0.bias"] = mega[
             f"decoder.layers.{i}.self_attention.linear_proj.bias"
         ]
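A plausible reading of the dropped `.T` (my interpretation, not stated in the commit): `torch.nn.Linear` stores its weight as [out_features, in_features], and if the source checkpoint already uses that layout, transposing loads a wrong matrix; for a square projection the mistake would even pass shape checks. A quick illustration of the layout:

import torch.nn as nn

# nn.Linear stores weight as [out_features, in_features]:
proj = nn.Linear(4096, 2048)
print(proj.weight.shape)  # torch.Size([2048, 4096])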
@@ -252,7 +260,7 @@ def convert_cogview4_vae_checkpoint_to_diffusers(ckpt_path, vae_config):
     Returns:
         dict: The converted VAE state dictionary compatible with Diffusers.
     """
-    original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
+    original_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)["state_dict"]
     return convert_ldm_vae_checkpoint(original_state_dict, vae_config)
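Both `weights_only=False` additions address the same upstream change: since PyTorch 2.6, `torch.load` defaults to `weights_only=True`, which refuses to unpickle arbitrary Python objects, so checkpoints that bundle non-tensor objects alongside weights fail to load. The script opts out explicitly; the path below is a placeholder:

import torch

# weights_only=False runs the full pickle machinery, which can execute
# arbitrary code, so only use it on checkpoints you trust.
ckpt = torch.load("checkpoint.pt", map_location="cpu", weights_only=False)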
@@ -286,7 +294,7 @@ def main(args):
     )
     transformer = CogView4Transformer2DModel(
         patch_size=2,
-        in_channels=16,
+        in_channels=32 if args.control else 16,
         num_layers=args.num_layers,
         attention_head_dim=args.attention_head_dim,
         num_attention_heads=args.num_heads,
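The doubled `in_channels` is the usual channel-concatenation recipe for control models: the VAE-encoded control image is stacked with the noisy latents along the channel axis before patchification. A hedged sketch of the idea (variable names and the concatenation order are illustrative, not taken from this diff):

import torch

latent_channels = 16
noisy_latents = torch.randn(1, latent_channels, 128, 128)    # denoising target
control_latents = torch.randn(1, latent_channels, 128, 128)  # VAE-encoded control image

# Channel-wise concatenation doubles the transformer's input channels: 16 -> 32.
model_input = torch.cat([noisy_latents, control_latents], dim=1)
print(model_input.shape)  # torch.Size([1, 32, 128, 128])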
@@ -317,6 +325,7 @@ def main(args):
         "norm_num_groups": 32,
         "sample_size": 1024,
         "scaling_factor": 1.0,
+        "shift_factor": 0.0,
         "force_upcast": True,
         "use_quant_conv": False,
         "use_post_quant_conv": False,
@@ -331,7 +340,7 @@ def main(args):
     # Load the text encoder and tokenizer
     text_encoder_id = "THUDM/glm-4-9b-hf"
     tokenizer = PreTrainedTokenizerFast.from_pretrained(text_encoder_id)
-    text_encoder = GlmForCausalLM.from_pretrained(
+    text_encoder = GlmModel.from_pretrained(
         text_encoder_id,
         cache_dir=args.text_encoder_cache_dir,
         torch_dtype=torch.bfloat16 if args.dtype == "bf16" else torch.float32,
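On the `GlmForCausalLM` to `GlmModel` switch: in transformers, `GlmForCausalLM` wraps `GlmModel` and adds an `lm_head` for next-token logits, while a diffusion text encoder only needs the hidden states, so the bare backbone suffices and the `lm_head` weights are never loaded. A sketch of the encoding path (the prompt is illustrative):

import torch
from transformers import GlmModel, PreTrainedTokenizerFast

tokenizer = PreTrainedTokenizerFast.from_pretrained("THUDM/glm-4-9b-hf")
text_encoder = GlmModel.from_pretrained("THUDM/glm-4-9b-hf", torch_dtype=torch.bfloat16)

inputs = tokenizer("a photo of a cat", return_tensors="pt")
with torch.no_grad():
    prompt_embeds = text_encoder(**inputs).last_hidden_state  # [1, seq_len, hidden]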
@@ -345,13 +354,22 @@ def main(args):
     )

     # Create the pipeline
-    pipe = CogView4Pipeline(
-        tokenizer=tokenizer,
-        text_encoder=text_encoder,
-        vae=vae,
-        transformer=transformer,
-        scheduler=scheduler,
-    )
+    if args.control:
+        pipe = CogView4ControlPipeline(
+            tokenizer=tokenizer,
+            text_encoder=text_encoder,
+            vae=vae,
+            transformer=transformer,
+            scheduler=scheduler,
+        )
+    else:
+        pipe = CogView4Pipeline(
+            tokenizer=tokenizer,
+            text_encoder=text_encoder,
+            vae=vae,
+            transformer=transformer,
+            scheduler=scheduler,
+        )

     # Save the converted pipeline
     pipe.save_pretrained(
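Once saved, the converted checkpoint loads like any diffusers pipeline. A hedged usage sketch: the checkpoint path and image files are placeholders, and the `control_image` argument follows the convention of other diffusers control pipelines rather than anything shown in this diff:

import torch
from diffusers import CogView4ControlPipeline
from diffusers.utils import load_image

pipe = CogView4ControlPipeline.from_pretrained(
    "path/to/converted-checkpoint", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

control_image = load_image("canny_edges.png")  # placeholder conditioning image
image = pipe(
    prompt="a red sports car on a coastal road",
    control_image=control_image,
).images[0]
image.save("output.png")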