mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
This commit is contained in:
zRzRzRzRzRzRzR
2026-01-07 16:55:02 +08:00
parent bcc9c303f6
commit e13fb76552
13 changed files with 93 additions and 310 deletions

View File

@@ -354,7 +354,7 @@
- local: api/models/flux_transformer
title: FluxTransformer2DModel
- local: api/models/glm_image_transformer2d
title: GlmImageDecoderTransformer2DModel
title: GlmImageTransformer2DModel
- local: api/models/hidream_image_transformer
title: HiDreamImageTransformer2DModel
- local: api/models/hunyuan_transformer2d

View File

@@ -9,10 +9,10 @@ Unless required by applicable law or agreed to in writing, software distributed
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# GlmImageDecoderTransformer2DModel
# GlmImageTransformer2DModel
A Diffusion Transformer model for 2D data from [GlmImageDecoderTransformer2DModel]()
A Diffusion Transformer model for 2D data from [GlmImageTransformer2DModel]()
## GlmImageDecoderTransformer2DModel
## GlmImageTransformer2DModel
[[autodoc]] GlmImageDecoderTransformer2DModel
[[autodoc]] GlmImageTransformer2DModel
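For orientation, a minimal sketch of loading the renamed model class directly; the repository id and subfolder layout below are assumptions for illustration, not taken from this commit.

```python
# Minimal sketch, assuming a published checkpoint at the hypothetical repo id
# "zai-org/GLM-Image" with the transformer weights under a "transformer" subfolder.
import torch
from diffusers import GlmImageTransformer2DModel

transformer = GlmImageTransformer2DModel.from_pretrained(
    "zai-org/GLM-Image",      # hypothetical repo id
    subfolder="transformer",  # assumed subfolder layout
    torch_dtype=torch.bfloat16,
)
```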

View File

@@ -20,12 +20,12 @@
This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/zai-org). The original weights can be found under [hf.co/zai-org](https://huggingface.co/zai-org).
## GlmImageDecoderPipeline
## GlmImagePipeline
[[autodoc]] GlmImageDecoderPipeline
[[autodoc]] GlmImagePipeline
- all
- __call__
## GlmImageDecoderPipelineOutput
## GlmImagePipelineOutput
[[autodoc]] pipelines.cogview4.pipeline_output.GlmImageDecoderPipelineOutput
[[autodoc]] pipelines.cogview4.pipeline_output.GlmImagePipelineOutput
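For reference, a minimal text-to-image sketch with the renamed pipeline, modeled on the EXAMPLE_DOC_STRING in pipeline_glm_image.py further down in this commit; the checkpoint id is the one still used there and may not be final.

```python
# Sketch based on the pipeline example docstring in this commit; the checkpoint
# id "THUDM/CogView4-6B" is carried over from that docstring and may change.
import torch
from diffusers import GlmImagePipeline

pipe = GlmImagePipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
image.save("glm_image.png")
```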

View File

@@ -223,7 +223,7 @@ else:
"FluxControlNetModel",
"FluxMultiControlNetModel",
"FluxTransformer2DModel",
"GlmImageDecoderTransformer2DModel",
"GlmImageTransformer2DModel",
"HiDreamImageTransformer2DModel",
"HunyuanDiT2DControlNetModel",
"HunyuanDiT2DModel",
@@ -488,7 +488,7 @@ else:
"FluxKontextPipeline",
"FluxPipeline",
"FluxPriorReduxPipeline",
"GlmImageDecoderPipeline",
"GlmImagePipeline",
"HiDreamImagePipeline",
"HunyuanDiTControlNetPipeline",
"HunyuanDiTPAGPipeline",
@@ -971,7 +971,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxControlNetModel,
FluxMultiControlNetModel,
FluxTransformer2DModel,
GlmImageDecoderTransformer2DModel,
GlmImageTransformer2DModel,
HiDreamImageTransformer2DModel,
HunyuanDiT2DControlNetModel,
HunyuanDiT2DModel,
@@ -1206,7 +1206,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxKontextPipeline,
FluxPipeline,
FluxPriorReduxPipeline,
GlmImageDecoderPipeline,
GlmImagePipeline,
HiDreamImagePipeline,
HunyuanDiTControlNetPipeline,
HunyuanDiTPAGPipeline,

View File

@@ -96,7 +96,7 @@ if is_torch_available():
_import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"]
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
_import_structure["transformers.transformer_flux2"] = ["Flux2Transformer2DModel"]
_import_structure["transformers.transformer_glm_image"] = ["GlmImageDecoderTransformer2DModel"]
_import_structure["transformers.transformer_glm_image"] = ["GlmImageTransformer2DModel"]
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
_import_structure["transformers.transformer_hunyuan_video15"] = ["HunyuanVideo15Transformer3DModel"]
@@ -204,7 +204,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
EasyAnimateTransformer3DModel,
Flux2Transformer2DModel,
FluxTransformer2DModel,
GlmImageDecoderTransformer2DModel,
GlmImageTransformer2DModel,
HiDreamImageTransformer2DModel,
HunyuanDiT2DModel,
HunyuanImageTransformer2DModel,

View File

@@ -1658,7 +1658,7 @@ class CogView3CombinedTimestepSizeEmbeddings(nn.Module):
return conditioning
class GlmImageDecoderCombinedTimestepSizeEmbeddings(nn.Module):
class GlmImageCombinedTimestepSizeEmbeddings(nn.Module):
def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256):
super().__init__()

View File

@@ -27,7 +27,7 @@ if is_torch_available():
from .transformer_easyanimate import EasyAnimateTransformer3DModel
from .transformer_flux import FluxTransformer2DModel
from .transformer_flux2 import Flux2Transformer2DModel
from .transformer_glm_image import GlmImageDecoderTransformer2DModel
from .transformer_glm_image import GlmImageTransformer2DModel
from .transformer_hidream_image import HiDreamImageTransformer2DModel
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
from .transformer_hunyuan_video15 import HunyuanVideo15Transformer3DModel

View File

@@ -27,7 +27,7 @@ from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention
from ..cache_utils import CacheMixin
from ..embeddings import GlmImageDecoderCombinedTimestepSizeEmbeddings
from ..embeddings import GlmImageCombinedTimestepSizeEmbeddings
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import LayerNorm, RMSNorm
@@ -36,7 +36,7 @@ from ..normalization import LayerNorm, RMSNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class GlmImageDecoderImageProjector(nn.Module):
class GlmImageImageProjector(nn.Module):
def __init__(
self,
in_channels: int = 16,
@@ -62,7 +62,7 @@ class GlmImageDecoderImageProjector(nn.Module):
return hidden_states
class GlmImageDecoderAdaLayerNormZero(nn.Module):
class GlmImageAdaLayerNormZero(nn.Module):
def __init__(self, embedding_dim: int, dim: int) -> None:
super().__init__()
@@ -71,11 +71,11 @@ class GlmImageDecoderAdaLayerNormZero(nn.Module):
self.linear = nn.Linear(embedding_dim, 12 * dim, bias=True)
def forward(
self, hidden_states: torch.Tensor, glyph_hidden_states: torch.Tensor, temb: torch.Tensor
self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
dtype = hidden_states.dtype
norm_hidden_states = self.norm(hidden_states).to(dtype=dtype)
norm_glyph_hidden_states = self.norm_context(glyph_hidden_states).to(dtype=dtype)
norm_encoder_hidden_states = self.norm_context(encoder_hidden_states).to(dtype=dtype)
emb = self.linear(temb)
(
@@ -94,7 +94,7 @@ class GlmImageDecoderAdaLayerNormZero(nn.Module):
) = emb.chunk(12, dim=1)
hidden_states = norm_hidden_states * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
glyph_hidden_states = norm_glyph_hidden_states * (1 + c_scale_msa.unsqueeze(1)) + c_shift_msa.unsqueeze(1)
encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_msa.unsqueeze(1)) + c_shift_msa.unsqueeze(1)
return (
hidden_states,
@@ -102,7 +102,7 @@ class GlmImageDecoderAdaLayerNormZero(nn.Module):
shift_mlp,
scale_mlp,
gate_mlp,
glyph_hidden_states,
encoder_hidden_states,
c_gate_msa,
c_shift_mlp,
c_scale_mlp,
@@ -110,17 +110,17 @@ class GlmImageDecoderAdaLayerNormZero(nn.Module):
)
class GlmImageDecoderAttenProcessorState(Enum):
class GlmImageAttenProcessorState(Enum):
ImageGen = "ImageGen"
ImageEditWriteKV = "ImageEditWriteKV"
ImageEditReadKV = "ImageEditReadKV"
ImageEditDontReadKV = "ImageEditNoReadKV"
class GlmImageDecoderAttnProcessor:
class GlmImageAttnProcessor:
"""
Processor for implementing scaled dot-product attention for the GlmImageDecoder model. It applies a rotary
embedding on query and key vectors, but does not include spatial normalization.
Processor for implementing scaled dot-product attention for the GlmImage model. It applies a rotary embedding on
query and key vectors, but does not include spatial normalization.
The processor supports passing an attention mask for text tokens. The attention mask should have shape (batch_size,
text_seq_length) where 1 indicates a non-padded token and 0 indicates a padded token.
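A small sketch of the mask convention described above, shape (batch_size, text_seq_length) with 1 for real tokens and 0 for padding; the lengths below are illustrative only.

```python
# Illustrative only: building a text attention mask in the shape the processor
# expects, (batch_size, text_seq_length), with 1 = kept token, 0 = padding.
import torch

batch_size, text_seq_length = 2, 6
token_counts = [6, 4]  # hypothetical unpadded lengths per sample

attention_mask = torch.zeros(batch_size, text_seq_length, dtype=torch.long)
for i, n in enumerate(token_counts):
    attention_mask[i, :n] = 1
```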
@@ -128,10 +128,8 @@ class GlmImageDecoderAttnProcessor:
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"GlmImageDecoderAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
)
self.processor_state = GlmImageDecoderAttenProcessorState.ImageGen
raise ImportError("GlmImageAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
self.processor_state = GlmImageAttenProcessorState.ImageGen
self.k_cache = None
self.v_cache = None
@@ -175,10 +173,10 @@ class GlmImageDecoderAttnProcessor:
key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
)
if self.processor_state == GlmImageDecoderAttenProcessorState.ImageEditWriteKV:
if self.processor_state == GlmImageAttenProcessorState.ImageEditWriteKV:
self.k_cache = key if self.k_cache is None else torch.cat([self.k_cache, key], dim=2)
self.v_cache = value if self.v_cache is None else torch.cat([self.v_cache, value], dim=2)
elif self.processor_state == GlmImageDecoderAttenProcessorState.ImageEditReadKV:
elif self.processor_state == GlmImageAttenProcessorState.ImageEditReadKV:
key = torch.cat([self.k_cache, key], dim=2) if self.k_cache is not None else key
value = torch.cat([self.v_cache, value], dim=2) if self.v_cache is not None else value
@@ -210,7 +208,7 @@ class GlmImageDecoderAttnProcessor:
@maybe_allow_in_graph
class GlmImageDecoderTransformerBlock(nn.Module):
class GlmImageTransformerBlock(nn.Module):
def __init__(
self,
dim: int = 2560,
@@ -221,7 +219,7 @@ class GlmImageDecoderTransformerBlock(nn.Module):
super().__init__()
# 1. Attention
self.norm1 = GlmImageDecoderAdaLayerNormZero(time_embed_dim, dim)
self.norm1 = GlmImageAdaLayerNormZero(time_embed_dim, dim)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
@@ -231,7 +229,7 @@ class GlmImageDecoderTransformerBlock(nn.Module):
qk_norm="layer_norm",
elementwise_affine=False,
eps=1e-5,
processor=GlmImageDecoderAttnProcessor(),
processor=GlmImageAttnProcessor(),
)
# 2. Feedforward
@@ -242,7 +240,7 @@ class GlmImageDecoderTransformerBlock(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
glyph_hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[
Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
@@ -257,42 +255,42 @@ class GlmImageDecoderTransformerBlock(nn.Module):
shift_mlp,
scale_mlp,
gate_mlp,
norm_glyph_hidden_states,
norm_encoder_hidden_states,
c_gate_msa,
c_shift_mlp,
c_scale_mlp,
c_gate_mlp,
) = self.norm1(hidden_states, glyph_hidden_states, temb)
) = self.norm1(hidden_states, encoder_hidden_states, temb)
# 2. Attention
if attention_kwargs is None:
attention_kwargs = {}
attn_hidden_states, attn_glyph_hidden_states = self.attn1(
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_glyph_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask,
**attention_kwargs,
)
hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1)
glyph_hidden_states = glyph_hidden_states + attn_glyph_hidden_states * c_gate_msa.unsqueeze(1)
encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1)
# 3. Feedforward
norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
norm_glyph_hidden_states = self.norm2_context(glyph_hidden_states) * (
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) * (
1 + c_scale_mlp.unsqueeze(1)
) + c_shift_mlp.unsqueeze(1)
ff_output = self.ff(norm_hidden_states)
ff_output_context = self.ff(norm_glyph_hidden_states)
ff_output_context = self.ff(norm_encoder_hidden_states)
hidden_states = hidden_states + ff_output * gate_mlp.unsqueeze(1)
glyph_hidden_states = glyph_hidden_states + ff_output_context * c_gate_mlp.unsqueeze(1)
encoder_hidden_states = encoder_hidden_states + ff_output_context * c_gate_mlp.unsqueeze(1)
return hidden_states, glyph_hidden_states
return hidden_states, encoder_hidden_states
class GlmImageDecoderRotaryPosEmbed(nn.Module):
class GlmImageRotaryPosEmbed(nn.Module):
def __init__(self, dim: int, patch_size: int, theta: float = 10000.0) -> None:
super().__init__()
@@ -331,10 +329,10 @@ class GlmImageDecoderRotaryPosEmbed(nn.Module):
return (freqs.cos(), freqs.sin())
class GlmImageDecoderAdaLayerNormContinuous(nn.Module):
class GlmImageAdaLayerNormContinuous(nn.Module):
"""
GlmImageDecoder-only final AdaLN: LN(x) -> Linear(cond) -> chunk -> affine. Matches Megatron: **no activation**
before the Linear on conditioning embedding.
GlmImage-only final AdaLN: LN(x) -> Linear(cond) -> chunk -> affine. Matches Megatron: **no activation** before the
Linear on conditioning embedding.
"""
def __init__(
@@ -363,7 +361,7 @@ class GlmImageDecoderAdaLayerNormContinuous(nn.Module):
return x
class GlmImageDecoderTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
class GlmImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
r"""
Args:
patch_size (`int`, defaults to `2`):
@@ -397,9 +395,9 @@ class GlmImageDecoderTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
_supports_gradient_checkpointing = True
_no_split_modules = [
"GlmImageDecoderTransformerBlock",
"GlmImageDecoderImageProjector",
"GlmImageDecoderImageProjector",
"GlmImageTransformerBlock",
"GlmImageImageProjector",
"GlmImageImageProjector",
]
_skip_layerwise_casting_patterns = ["patch_embed", "norm", "proj_out"]
@@ -412,35 +410,30 @@ class GlmImageDecoderTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
num_layers: int = 30,
attention_head_dim: int = 40,
num_attention_heads: int = 64,
text_embed_dim: int = 4096,
glyph_embed_dim: int = 1472,
text_embed_dim: int = 1472,
time_embed_dim: int = 512,
condition_dim: int = 256,
pos_embed_max_size: int = 128,
sample_size: int = 128,
prior_vq_quantizer_codebook_size: int = 16384,
):
super().__init__()
# GlmImageDecoder uses 2 additional SDXL-like conditions - target_size, crop_coords
# GlmImage uses 2 additional SDXL-like conditions - target_size, crop_coords
# Each of these are sincos embeddings of shape 2 * condition_dim
pooled_projection_dim = 2 * 2 * condition_dim
inner_dim = num_attention_heads * attention_head_dim
out_channels = out_channels
# 1. RoPE
self.rope = GlmImageDecoderRotaryPosEmbed(attention_head_dim, patch_size, theta=10000.0)
self.rope = GlmImageRotaryPosEmbed(attention_head_dim, patch_size, theta=10000.0)
# 2. Patch & Text-timestep embedding
self.image_projector = GlmImageDecoderImageProjector(in_channels, inner_dim, patch_size)
# Not included this time; a text_projector may be added in the future
# self.text_projector = FeedForward(text_embed_dim, inner_dim, activation_fn="gelu")
self.glyph_projector = FeedForward(glyph_embed_dim, inner_dim, inner_dim=inner_dim, activation_fn="gelu")
self.image_projector = GlmImageImageProjector(in_channels, inner_dim, patch_size)
self.text_projector = FeedForward(text_embed_dim, inner_dim, inner_dim=inner_dim, activation_fn="gelu")
self.prior_token_embedding = nn.Embedding(prior_vq_quantizer_codebook_size, inner_dim)
self.prior_projector = FeedForward(inner_dim, inner_dim, inner_dim=inner_dim, activation_fn="linear-silu")
self.time_condition_embed = GlmImageDecoderCombinedTimestepSizeEmbeddings(
self.time_condition_embed = GlmImageCombinedTimestepSizeEmbeddings(
embedding_dim=time_embed_dim,
condition_dim=condition_dim,
pooled_projection_dim=pooled_projection_dim,
@@ -450,13 +443,13 @@ class GlmImageDecoderTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
# 3. Transformer blocks
self.transformer_blocks = nn.ModuleList(
[
GlmImageDecoderTransformerBlock(inner_dim, num_attention_heads, attention_head_dim, time_embed_dim)
GlmImageTransformerBlock(inner_dim, num_attention_heads, attention_head_dim, time_embed_dim)
for _ in range(num_layers)
]
)
# 4. Output projection
self.norm_out = GlmImageDecoderAdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
self.norm_out = GlmImageAdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels, bias=True)
self.gradient_checkpointing = False
@@ -464,11 +457,10 @@ class GlmImageDecoderTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
def forward(
self,
hidden_states: torch.Tensor,
glyph_hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
prior_token_id: torch.Tensor,
prior_token_drop: torch.Tensor,
timestep: torch.LongTensor,
original_size: torch.Tensor,
target_size: torch.Tensor,
crop_coords: torch.Tensor,
attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -490,7 +482,7 @@ class GlmImageDecoderTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
post_patch_width = width // p
hidden_states = self.image_projector(hidden_states)
glyph_hidden_states = self.glyph_projector(glyph_hidden_states)
encoder_hidden_states = self.text_projector(encoder_hidden_states)
prior_embedding = self.prior_token_embedding(prior_token_id)
prior_embedding[prior_token_drop] *= 0.0
prior_hidden_states = self.prior_projector(prior_embedding)
@@ -503,19 +495,19 @@ class GlmImageDecoderTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
# 3. Transformer blocks
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states, glyph_hidden_states = self._gradient_checkpointing_func(
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
glyph_hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
attention_mask,
attention_kwargs,
)
else:
hidden_states, glyph_hidden_states = block(
hidden_states, encoder_hidden_states = block(
hidden_states,
glyph_hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
attention_mask,
@@ -534,7 +526,7 @@ class GlmImageDecoderTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
return (output,)
return Transformer2DModelOutput(sample=output)
def set_attention_processors_state(self, state: GlmImageDecoderAttenProcessorState):
def set_attention_processors_state(self, state: GlmImageAttenProcessorState):
for block in self.transformer_blocks:
block.attn1.processor.processor_state = state
@@ -542,212 +534,3 @@ class GlmImageDecoderTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
for block in self.transformer_blocks:
block.attn1.processor.k_cache = None
block.attn1.processor.v_cache = None
def repeat_attention_processors_cache(self, repeats: int):
for block in self.transformer_blocks:
if block.attn1.processor.k_cache is None or block.attn1.processor.v_cache is None:
continue
block.attn1.processor.k_cache = torch.repeat_interleave(block.attn1.processor.k_cache, repeats, dim=2)
block.attn1.processor.v_cache = torch.repeat_interleave(block.attn1.processor.v_cache, repeats, dim=2)
if __name__ == "__main__":
def swap_scale_shift(weight, dim):
"""
Swap the scale and shift components in the weight tensor.
Args:
weight (torch.Tensor): The original weight tensor.
dim (int): The dimension along which to split.
Returns:
torch.Tensor: The modified weight tensor with scale and shift swapped.
"""
shift, scale = weight.chunk(2, dim=dim)
new_weight = torch.cat([scale, shift], dim=dim)
return new_weight
def convert_megatron_transformer_checkpoint_to_diffusers(
ckpt_path: str,
num_layers: int,
num_heads: int,
hidden_size: int,
):
"""
Convert a Megatron Transformer checkpoint to Diffusers format.
Args:
ckpt_path (str): Path to the Megatron Transformer checkpoint.
num_layers (int): Number of Transformer layers.
num_heads (int): Number of attention heads.
hidden_size (int): Hidden size of the Transformer.
Returns:
dict: The converted state dictionary compatible with Diffusers.
"""
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
mega = ckpt["model"]
used_keys = set()
def get_mega(key):
used_keys.add(key)
return mega[key]
new_state_dict = {}
# Patch Embedding
new_state_dict["image_projector.proj.weight"] = get_mega("encoder_expand_linear.weight").reshape(
hidden_size, 64
)
new_state_dict["image_projector.proj.bias"] = get_mega("encoder_expand_linear.bias")
new_state_dict["glyph_projector.net.0.proj.weight"] = get_mega("glyph_projector.linear_fc1.weight")
new_state_dict["glyph_projector.net.0.proj.bias"] = get_mega("glyph_projector.linear_fc1.bias")
new_state_dict["glyph_projector.net.2.weight"] = get_mega("glyph_projector.linear_fc2.weight")
new_state_dict["glyph_projector.net.2.bias"] = get_mega("glyph_projector.linear_fc2.bias")
new_state_dict["prior_token_embedding.weight"] = get_mega("xomni_token_id_embedding.weight")
new_state_dict["prior_projector.net.0.proj.weight"] = get_mega("prior_condition_embedding.0.weight")
new_state_dict["prior_projector.net.0.proj.bias"] = get_mega("prior_condition_embedding.0.bias")
new_state_dict["prior_projector.net.2.weight"] = get_mega("prior_condition_embedding.2.weight")
new_state_dict["prior_projector.net.2.bias"] = get_mega("prior_condition_embedding.2.bias")
# Time Condition Embedding
new_state_dict["time_condition_embed.timestep_embedder.linear_1.weight"] = get_mega(
"time_embedding.time_embed.0.weight"
)
new_state_dict["time_condition_embed.timestep_embedder.linear_1.bias"] = get_mega(
"time_embedding.time_embed.0.bias"
)
new_state_dict["time_condition_embed.timestep_embedder.linear_2.weight"] = get_mega(
"time_embedding.time_embed.2.weight"
)
new_state_dict["time_condition_embed.timestep_embedder.linear_2.bias"] = get_mega(
"time_embedding.time_embed.2.bias"
)
new_state_dict["time_condition_embed.condition_embedder.linear_1.weight"] = get_mega(
"label_embedding.label_embed.0.weight"
)
new_state_dict["time_condition_embed.condition_embedder.linear_1.bias"] = get_mega(
"label_embedding.label_embed.0.bias"
)
new_state_dict["time_condition_embed.condition_embedder.linear_2.weight"] = get_mega(
"label_embedding.label_embed.2.weight"
)
new_state_dict["time_condition_embed.condition_embedder.linear_2.bias"] = get_mega(
"label_embedding.label_embed.2.bias"
)
# Convert each Transformer layer
from tqdm import tqdm
for i in tqdm(range(num_layers), desc="Converting layers (Megatron->Diffusers)"):
block_prefix = f"transformer_blocks.{i}."
# AdaLayerNorm
new_state_dict[block_prefix + "norm1.linear.weight"] = get_mega(f"decoder.layers.{i}.adaln.weight")
new_state_dict[block_prefix + "norm1.linear.bias"] = get_mega(f"decoder.layers.{i}.adaln.bias")
qkv_weight = get_mega(f"decoder.layers.{i}.self_attention.linear_qkv.weight")
qkv_bias = get_mega(f"decoder.layers.{i}.self_attention.linear_qkv.bias")
# Reshape to match SAT logic
qkv_weight = qkv_weight.view(num_heads, 3, hidden_size // num_heads, hidden_size)
qkv_weight = qkv_weight.permute(1, 0, 2, 3).reshape(3 * hidden_size, hidden_size)
qkv_bias = qkv_bias.view(num_heads, 3, hidden_size // num_heads)
qkv_bias = qkv_bias.permute(1, 0, 2).reshape(3 * hidden_size)
# Assign to Diffusers keys
q, k, v = torch.chunk(qkv_weight, 3, dim=0)
qb, kb, vb = torch.chunk(qkv_bias, 3, dim=0)
new_state_dict[block_prefix + "attn1.to_q.weight"] = q
new_state_dict[block_prefix + "attn1.to_q.bias"] = qb
new_state_dict[block_prefix + "attn1.to_k.weight"] = k
new_state_dict[block_prefix + "attn1.to_k.bias"] = kb
new_state_dict[block_prefix + "attn1.to_v.weight"] = v
new_state_dict[block_prefix + "attn1.to_v.bias"] = vb
# Attention Output
new_state_dict[block_prefix + "attn1.to_out.0.weight"] = get_mega(
f"decoder.layers.{i}.self_attention.linear_proj.weight"
)
new_state_dict[block_prefix + "attn1.to_out.0.bias"] = get_mega(
f"decoder.layers.{i}.self_attention.linear_proj.bias"
)
# MLP
new_state_dict[block_prefix + "ff.net.0.proj.weight"] = get_mega(
f"decoder.layers.{i}.mlp.linear_fc1.weight"
)
new_state_dict[block_prefix + "ff.net.0.proj.bias"] = get_mega(f"decoder.layers.{i}.mlp.linear_fc1.bias")
new_state_dict[block_prefix + "ff.net.2.weight"] = get_mega(f"decoder.layers.{i}.mlp.linear_fc2.weight")
new_state_dict[block_prefix + "ff.net.2.bias"] = get_mega(f"decoder.layers.{i}.mlp.linear_fc2.bias")
# Final Layers
new_state_dict["norm_out.linear.weight"] = swap_scale_shift(get_mega("adaln_final.weight"), dim=0)
new_state_dict["norm_out.linear.bias"] = swap_scale_shift(get_mega("adaln_final.bias"), dim=0)
new_state_dict["proj_out.weight"] = get_mega("output_projector.weight")
new_state_dict["proj_out.bias"] = get_mega("output_projector.bias")
# Check for unused keys
all_keys = set(mega.keys())
unused_keys = all_keys - used_keys
if unused_keys:
print(f"\n[WARNING] The following {len(unused_keys)} keys in mega were NOT used:")
for key in sorted(unused_keys):
print(f" - {key}")
raise ValueError(
f"Found {len(unused_keys)} unused keys in Megatron checkpoint. Please update the conversion script to handle these keys."
)
else:
print(f"\n[INFO] All {len(all_keys)} keys in mega were successfully used.")
return new_state_dict
transformer = GlmImageDecoderTransformer2DModel(
patch_size=2,
in_channels=16,
num_layers=30,
attention_head_dim=128,
num_attention_heads=32,
out_channels=16,
text_embed_dim=4096,
time_embed_dim=512,
glyph_embed_dim=1472,
condition_dim=256,
pos_embed_max_size=128,
).to(torch.bfloat16)
converted_transformer_state_dict = convert_megatron_transformer_checkpoint_to_diffusers(
ckpt_path="/workspace/ckpt/tjy/Glm-train-dev/examples/cogview/ckpts/merge/1+6_0.5+0.5/iter_0000000/mp_rank_00/model_optim_rng.pt",
num_layers=30,
num_heads=32,
hidden_size=4096,
)
transformer.load_state_dict(converted_transformer_state_dict)
transformer.cuda()
latent = torch.load("/workspace/ckpt/tjy/glm-train-dev/examples/cogview/latent.pt").to(torch.bfloat16)
latent = rearrange(latent, "(b h w) (c p q) -> b c (h p) (w q)", b=8, h=72, w=54, p=2, q=2)
glyph_hidden_states = torch.load(
"/workspace/ckpt/tjy/glm-train-dev/examples/cogview/glyph_condition_embedding.pt"
).to(torch.bfloat16)
glyph_hidden_states = rearrange(glyph_hidden_states, "(b n) c -> b n c", b=8, n=2)
prior_token_id = torch.load("/workspace/ckpt/tjy/glm-train-dev/examples/cogview/xomni_token_id.pt")
prior_token_drop = torch.load("/workspace/ckpt/tjy/glm-train-dev/examples/cogview/xomni_drop.pt")
prior_token_id = rearrange(prior_token_id, "(b n) -> b n", b=8)
prior_token_drop = rearrange(prior_token_drop, "(b n)-> b n", b=8)
with torch.no_grad():
output = transformer(
hidden_states=latent,
glyph_hidden_states=glyph_hidden_states,
prior_token_id=prior_token_id,
prior_token_drop=prior_token_drop,
timestep=torch.tensor([999.0] * 8).cuda(),
original_size=torch.tensor([[144, 108]] * 8).cuda(),
target_size=torch.tensor([[144, 108]] * 8).cuda(),
crop_coords=torch.tensor([[0, 0]] * 8).cuda(),
)
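Taken together with the pipeline changes below, the cache-state helpers above imply the following image-edit flow; this is a sketch of the call order only, using an assumed `transformer` instance (a loaded GlmImageTransformer2DModel), not code from this commit.

```python
# Sketch of the image-edit KV-cache flow implied by the helpers above and by
# pipeline_glm_image.py in this commit; `transformer` is assumed to be a loaded
# GlmImageTransformer2DModel and the surrounding pipeline logic is omitted.
from diffusers.models.transformers.transformer_glm_image import GlmImageAttenProcessorState

# 1. Run the condition image latents through the transformer once, caching their K/V.
transformer.set_attention_processors_state(GlmImageAttenProcessorState.ImageEditWriteKV)
# ... forward pass over the condition latents ...

# 2. During denoising, the conditional branch reads the cached K/V,
#    while the unconditional (CFG) branch skips it.
transformer.set_attention_processors_state(GlmImageAttenProcessorState.ImageEditReadKV)
# ... conditional forward pass ...
transformer.set_attention_processors_state(GlmImageAttenProcessorState.ImageEditDontReadKV)
# ... unconditional forward pass ...

# 3. Reset to plain generation and drop the cache afterwards.
transformer.set_attention_processors_state(GlmImageAttenProcessorState.ImageGen)
transformer.clear_attention_processors_cache()
```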

View File

@@ -52,7 +52,7 @@ from .flux import (
FluxKontextPipeline,
FluxPipeline,
)
from .glm_image import GlmImageDecoderPipeline
from .glm_image import GlmImagePipeline
from .hunyuandit import HunyuanDiTPipeline
from .kandinsky import (
KandinskyCombinedPipeline,
@@ -168,7 +168,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("chroma", ChromaPipeline),
("cogview3", CogView3PlusPipeline),
("cogview4", CogView4Pipeline),
("glm_image", GlmImageDecoderPipeline),
("glm_image", GlmImagePipeline),
("cogview4-control", CogView4ControlPipeline),
("qwenimage", QwenImagePipeline),
("qwenimage-controlnet", QwenImageControlNetPipeline),

View File

@@ -12,7 +12,7 @@ from ...utils import (
_dummy_objects = {}
_additional_imports = {}
_import_structure = {"pipeline_output": ["GlmImageDecoderPipelineOutput"]}
_import_structure = {"pipeline_output": ["GlmImagePipelineOutput"]}
try:
if not (is_transformers_available() and is_torch_available()):
@@ -22,7 +22,7 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_glm_image"] = ["GlmImageDecoderPipeline"]
_import_structure["pipeline_glm_image"] = ["GlmImagePipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
@@ -30,7 +30,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .pipeline_glm_image import GlmImageDecoderPipeline
from .pipeline_glm_image import GlmImagePipeline
else:
import sys

View File

@@ -25,13 +25,13 @@ from transformers import AutoTokenizer, T5EncoderModel
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import VaeImageProcessor
from ...loaders import CogView4LoraLoaderMixin
from ...models import AutoencoderKL, GlmImageDecoderTransformer2DModel
from ...models.transformers.transformer_glm_image import GlmImageDecoderAttenProcessorState
from ...models import AutoencoderKL, GlmImageTransformer2DModel
from ...models.transformers.transformer_glm_image import GlmImageAttenProcessorState
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from .pipeline_output import GlmImageDecoderPipelineOutput
from .pipeline_output import GlmImagePipelineOutput
if is_torch_xla_available():
@@ -47,9 +47,9 @@ EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from diffusers import GlmImageDecoderPipeline
>>> from diffusers import GlmImagePipeline
>>> pipe = GlmImageDecoderPipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
>>> pipe = GlmImagePipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> prompt = "A photo of an astronaut riding a horse on mars"
@@ -151,7 +151,7 @@ def retrieve_latents(
raise AttributeError("Could not access latents of provided encoder_output")
class GlmImageDecoderPipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
r"""
Pipeline for text-to-image generation using CogView4.
@@ -162,7 +162,7 @@ class GlmImageDecoderPipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`GlmModel`]):
Frozen text-encoder. CogView4 uses [Glm-4-9b-hf](https://huggingface.co/THUDM/Glm-4-9b-hf).
Frozen text-encoder. CogView4 uses [GLM-Image](https://huggingface.co/zai-org/GLM-Image).
tokenizer (`PreTrainedTokenizer`):
Tokenizer of class
[PreTrainedTokenizer](https://huggingface.co/docs/transformers/main/en/main_classes/tokenizer#transformers.PreTrainedTokenizer).
@@ -179,15 +179,15 @@ class GlmImageDecoderPipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
def __init__(
self,
tokenizer: AutoTokenizer,
glyph_encoder: T5EncoderModel,
text_encoder: T5EncoderModel,
vae: AutoencoderKL,
transformer: GlmImageDecoderTransformer2DModel,
transformer: GlmImageTransformer2DModel,
scheduler: FlowMatchEulerDiscreteScheduler,
):
super().__init__()
self.register_modules(
tokenizer=tokenizer, glyph_encoder=glyph_encoder, vae=vae, transformer=transformer, scheduler=scheduler
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
@@ -221,7 +221,7 @@ class GlmImageDecoderPipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.glyph_encoder.dtype
dtype = dtype or self.text_encoder.dtype
glyph_texts = self.get_glyph_texts(prompt)
input_ids = self.tokenizer(
@@ -240,7 +240,7 @@ class GlmImageDecoderPipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
[input_ids_ + [self.tokenizer.pad_token_id] * (max_length - len(input_ids_)) for input_ids_ in input_ids],
device=device,
)
outputs = self.glyph_encoder(input_ids, attention_mask=attention_mask)
outputs = self.text_encoder(input_ids, attention_mask=attention_mask)
glyph_embeds = outputs.last_hidden_state[attention_mask.bool()].unsqueeze(0)
return glyph_embeds.to(device=device, dtype=dtype)
@@ -442,7 +442,7 @@ class GlmImageDecoderPipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 2048,
) -> Union[GlmImageDecoderPipelineOutput, Tuple]:
) -> Union[GlmImagePipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
@@ -615,7 +615,7 @@ class GlmImageDecoderPipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
)
if condition_images is not None:
self.transformer.set_attention_processors_state(GlmImageDecoderAttenProcessorState.ImageEditWriteKV)
self.transformer.set_attention_processors_state(GlmImageAttenProcessorState.ImageEditWriteKV)
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.latent_channels, 1, 1)
@@ -698,7 +698,7 @@ class GlmImageDecoderPipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
timestep = t.expand(latents.shape[0]) - 1
if condition_images is not None:
self.transformer.set_attention_processors_state(GlmImageDecoderAttenProcessorState.ImageEditReadKV)
self.transformer.set_attention_processors_state(GlmImageAttenProcessorState.ImageEditReadKV)
noise_pred_cond = self.transformer(
hidden_states=latent_model_input,
@@ -717,7 +717,7 @@ class GlmImageDecoderPipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
if self.do_classifier_free_guidance:
if condition_images is not None:
self.transformer.set_attention_processors_state(
GlmImageDecoderAttenProcessorState.ImageEditDontReadKV
GlmImageAttenProcessorState.ImageEditDontReadKV
)
noise_pred_uncond = self.transformer(
hidden_states=latent_model_input,
@@ -755,7 +755,7 @@ class GlmImageDecoderPipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
xm.mark_step()
self._current_timestep = None
self.transformer.set_attention_processors_state(GlmImageDecoderAttenProcessorState.ImageGen)
self.transformer.set_attention_processors_state(GlmImageAttenProcessorState.ImageGen)
self.transformer.clear_attention_processors_cache()
if not output_type == "latent":
@@ -783,4 +783,4 @@ class GlmImageDecoderPipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
if not return_dict:
return (condition_images,)
return GlmImageDecoderPipelineOutput(images=condition_images)
return GlmImagePipelineOutput(images=condition_images)

View File

@@ -8,7 +8,7 @@ from ...utils import BaseOutput
@dataclass
class GlmImageDecoderPipelineOutput(BaseOutput):
class GlmImagePipelineOutput(BaseOutput):
"""
Output class for CogView3 pipelines.

View File

@@ -967,7 +967,7 @@ class HiDreamImageTransformer2DModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class GlmImageDecoderTransformer2DModel(metaclass=DummyObject):
class GlmImageTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
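For context, the dummy objects in this file let `import diffusers` succeed without PyTorch and only fail when the class is actually used; a minimal sketch of the usual pattern (the rest of the class is cut off in the hunk above, so the body here is the conventional form, not text from this commit).

```python
# Typical shape of a diffusers dummy object (sketch): any attempt to construct
# or load the class raises an informative error asking for the "torch" backend.
from diffusers.utils import DummyObject, requires_backends


class GlmImageTransformer2DModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])
```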