1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

CogView4 Control Block (#10809)

* cogview4 control training


---------

Co-authored-by: OleehyO <leehy0357@gmail.com>
Co-authored-by: yiyixuxu <yixu310@gmail.com>
This commit is contained in:
Yuxuan Zhang
2025-03-16 01:15:56 +08:00
committed by GitHub
parent cc19726f3d
commit 82188cef04
13 changed files with 2286 additions and 36 deletions

View File

@@ -53,8 +53,18 @@ args = parser.parse_args()
# this is specific to `AdaLayerNormContinuous`:
# diffusers implementation split the linear projection into the scale, shift while CogView4 split it tino shift, scale
def swap_scale_shift(weight, dim):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
"""
Swap the scale and shift components in the weight tensor.
Args:
weight (torch.Tensor): The original weight tensor.
dim (int): The dimension along which to split.
Returns:
torch.Tensor: The modified weight tensor with scale and shift swapped.
"""
shift, scale = weight.chunk(2, dim=dim)
new_weight = torch.cat([scale, shift], dim=dim)
return new_weight
@@ -200,6 +210,7 @@ def main(args):
"norm_num_groups": 32,
"sample_size": 1024,
"scaling_factor": 1.0,
"shift_factor": 0.0,
"force_upcast": True,
"use_quant_conv": False,
"use_post_quant_conv": False,

View File

@@ -25,9 +25,15 @@ import argparse
import torch
from tqdm import tqdm
from transformers import GlmForCausalLM, PreTrainedTokenizerFast
from transformers import GlmModel, PreTrainedTokenizerFast
from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
from diffusers import (
AutoencoderKL,
CogView4ControlPipeline,
CogView4Pipeline,
CogView4Transformer2DModel,
FlowMatchEulerDiscreteScheduler,
)
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
@@ -112,6 +118,12 @@ parser.add_argument(
default=128,
help="Maximum size for positional embeddings.",
)
parser.add_argument(
"--control",
action="store_true",
default=False,
help="Whether to use control model.",
)
args = parser.parse_args()
@@ -150,13 +162,15 @@ def convert_megatron_transformer_checkpoint_to_diffusers(
Returns:
dict: The converted state dictionary compatible with Diffusers.
"""
ckpt = torch.load(ckpt_path, map_location="cpu")
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
mega = ckpt["model"]
new_state_dict = {}
# Patch Embedding
new_state_dict["patch_embed.proj.weight"] = mega["encoder_expand_linear.weight"].reshape(hidden_size, 64)
new_state_dict["patch_embed.proj.weight"] = mega["encoder_expand_linear.weight"].reshape(
hidden_size, 128 if args.control else 64
)
new_state_dict["patch_embed.proj.bias"] = mega["encoder_expand_linear.bias"]
new_state_dict["patch_embed.text_proj.weight"] = mega["text_projector.weight"]
new_state_dict["patch_embed.text_proj.bias"] = mega["text_projector.bias"]
@@ -189,14 +203,8 @@ def convert_megatron_transformer_checkpoint_to_diffusers(
block_prefix = f"transformer_blocks.{i}."
# AdaLayerNorm
new_state_dict[block_prefix + "norm1.linear.weight"] = swap_scale_shift(
mega[f"decoder.layers.{i}.adaln.weight"], dim=0
)
new_state_dict[block_prefix + "norm1.linear.bias"] = swap_scale_shift(
mega[f"decoder.layers.{i}.adaln.bias"], dim=0
)
# QKV
new_state_dict[block_prefix + "norm1.linear.weight"] = mega[f"decoder.layers.{i}.adaln.weight"]
new_state_dict[block_prefix + "norm1.linear.bias"] = mega[f"decoder.layers.{i}.adaln.bias"]
qkv_weight = mega[f"decoder.layers.{i}.self_attention.linear_qkv.weight"]
qkv_bias = mega[f"decoder.layers.{i}.self_attention.linear_qkv.bias"]
@@ -221,7 +229,7 @@ def convert_megatron_transformer_checkpoint_to_diffusers(
# Attention Output
new_state_dict[block_prefix + "attn1.to_out.0.weight"] = mega[
f"decoder.layers.{i}.self_attention.linear_proj.weight"
].T
]
new_state_dict[block_prefix + "attn1.to_out.0.bias"] = mega[
f"decoder.layers.{i}.self_attention.linear_proj.bias"
]
@@ -252,7 +260,7 @@ def convert_cogview4_vae_checkpoint_to_diffusers(ckpt_path, vae_config):
Returns:
dict: The converted VAE state dictionary compatible with Diffusers.
"""
original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
original_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)["state_dict"]
return convert_ldm_vae_checkpoint(original_state_dict, vae_config)
@@ -286,7 +294,7 @@ def main(args):
)
transformer = CogView4Transformer2DModel(
patch_size=2,
in_channels=16,
in_channels=32 if args.control else 16,
num_layers=args.num_layers,
attention_head_dim=args.attention_head_dim,
num_attention_heads=args.num_heads,
@@ -317,6 +325,7 @@ def main(args):
"norm_num_groups": 32,
"sample_size": 1024,
"scaling_factor": 1.0,
"shift_factor": 0.0,
"force_upcast": True,
"use_quant_conv": False,
"use_post_quant_conv": False,
@@ -331,7 +340,7 @@ def main(args):
# Load the text encoder and tokenizer
text_encoder_id = "THUDM/glm-4-9b-hf"
tokenizer = PreTrainedTokenizerFast.from_pretrained(text_encoder_id)
text_encoder = GlmForCausalLM.from_pretrained(
text_encoder = GlmModel.from_pretrained(
text_encoder_id,
cache_dir=args.text_encoder_cache_dir,
torch_dtype=torch.bfloat16 if args.dtype == "bf16" else torch.float32,
@@ -345,13 +354,22 @@ def main(args):
)
# Create the pipeline
pipe = CogView4Pipeline(
tokenizer=tokenizer,
text_encoder=text_encoder,
vae=vae,
transformer=transformer,
scheduler=scheduler,
)
if args.control:
pipe = CogView4ControlPipeline(
tokenizer=tokenizer,
text_encoder=text_encoder,
vae=vae,
transformer=transformer,
scheduler=scheduler,
)
else:
pipe = CogView4Pipeline(
tokenizer=tokenizer,
text_encoder=text_encoder,
vae=vae,
transformer=transformer,
scheduler=scheduler,
)
# Save the converted pipeline
pipe.save_pretrained(