This commit is contained in:
zRzRzRzRzRzRzR
2026-01-07 15:31:11 +08:00
parent 98479a94c2
commit ec9a82fc3f
9 changed files with 1690 additions and 0 deletions

View File

@@ -0,0 +1,18 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# GlmImageDecoderTransformer2DModel
A Diffusion Transformer model for 2D data from [GLM-Image]()
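A minimal loading sketch; the repository id below is a placeholder, and the checkpoint is assumed to follow the standard Diffusers layout with the transformer weights in a `transformer` subfolder:

```python
import torch

from diffusers import GlmImageDecoderTransformer2DModel

# hypothetical repository id; replace with the released GLM-Image checkpoint
transformer = GlmImageDecoderTransformer2DModel.from_pretrained(
    "<repo-id>", subfolder="transformer", torch_dtype=torch.bfloat16
)
```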
## GlmImageDecoderTransformer2DModel
[[autodoc]] GlmImageDecoderTransformer2DModel

View File

@@ -0,0 +1,31 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-->
# GLM-Image
> [!TIP]
> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/zai-org). The original weights can be found under [hf.co/zai-org](https://huggingface.co/zai-org).
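A minimal text-to-image sketch, mirroring the example docstring in the pipeline code; the checkpoint id is taken from that docstring and should be treated as a placeholder until the GLM-Image weights are published:

```python
import torch

from diffusers import GlmImageDecoderPipeline

# checkpoint id copied from the in-code example docstring; treat it as a placeholder
pipe = GlmImageDecoderPipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
image.save("output.png")
```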
## GlmImageDecoderPipeline
[[autodoc]] GlmImageDecoderPipeline
- all
- __call__
## GlmImageDecoderPipelineOutput
[[autodoc]] pipelines.glm_image.pipeline_output.GlmImageDecoderPipelineOutput

View File

@@ -96,6 +96,7 @@ if is_torch_available():
_import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"]
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
_import_structure["transformers.transformer_flux2"] = ["Flux2Transformer2DModel"]
_import_structure["transformers.transformer_glm_image"] = ["GlmImageDecoderTransformer2DModel"]
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
_import_structure["transformers.transformer_hunyuan_video15"] = ["HunyuanVideo15Transformer3DModel"]
@@ -203,6 +204,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
EasyAnimateTransformer3DModel,
Flux2Transformer2DModel,
FluxTransformer2DModel,
GlmImageDecoderTransformer2DModel,
HiDreamImageTransformer2DModel,
HunyuanDiT2DModel,
HunyuanImageTransformer2DModel,

View File

@@ -1658,6 +1658,37 @@ class CogView3CombinedTimestepSizeEmbeddings(nn.Module):
return conditioning
class GlmImageDecoderCombinedTimestepSizeEmbeddings(nn.Module):
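"""
Produces the conditioning vector for the adaptive layer norms by summing a sinusoidal timestep embedding with an
embedding of the SDXL-style size conditions (`crop_coords` and `target_size`), each projected to `embedding_dim`.
"""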
def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256):
super().__init__()
self.time_proj = Timesteps(num_channels=timesteps_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
self.condition_proj = Timesteps(num_channels=condition_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=timesteps_dim, time_embed_dim=embedding_dim)
self.condition_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
def forward(
self,
timestep: torch.Tensor,
target_size: torch.Tensor,
crop_coords: torch.Tensor,
hidden_dtype: torch.dtype,
) -> torch.Tensor:
timesteps_proj = self.time_proj(timestep)
crop_coords_proj = self.condition_proj(crop_coords.flatten()).view(crop_coords.size(0), -1)
target_size_proj = self.condition_proj(target_size.flatten()).view(target_size.size(0), -1)
# (B, 2 * condition_dim)
condition_proj = torch.cat([crop_coords_proj, target_size_proj], dim=1)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (B, embedding_dim)
condition_emb = self.condition_embedder(condition_proj.to(dtype=hidden_dtype)) # (B, embedding_dim)
conditioning = timesteps_emb + condition_emb
return conditioning
class HunyuanDiTAttentionPool(nn.Module):
# Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6

View File

@@ -27,6 +27,7 @@ if is_torch_available():
from .transformer_easyanimate import EasyAnimateTransformer3DModel
from .transformer_flux import FluxTransformer2DModel
from .transformer_flux2 import Flux2Transformer2DModel
from .transformer_glm_image import GlmImageDecoderTransformer2DModel
from .transformer_hidream_image import HiDreamImageTransformer2DModel
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
from .transformer_hunyuan_video15 import HunyuanVideo15Transformer3DModel

View File

@@ -0,0 +1,753 @@
# Copyright 2025 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention
from ..cache_utils import CacheMixin
from ..embeddings import GlmImageDecoderCombinedTimestepSizeEmbeddings
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import LayerNorm, RMSNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class GlmImageDecoderImageProjector(nn.Module):
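"""
Patchifies input latents of shape `(B, C, H, W)` into `(B, (H / p) * (W / p), hidden_size)` tokens with a single
linear projection, where `p` is the patch size.
"""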
def __init__(
self,
in_channels: int = 16,
hidden_size: int = 2560,
patch_size: int = 2,
):
super().__init__()
self.patch_size = patch_size
self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, channel, height, width = hidden_states.shape
post_patch_height = height // self.patch_size
post_patch_width = width // self.patch_size
hidden_states = hidden_states.reshape(
batch_size, channel, post_patch_height, self.patch_size, post_patch_width, self.patch_size
)
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2)
hidden_states = self.proj(hidden_states)
return hidden_states
class GlmImageDecoderAdaLayerNormZero(nn.Module):
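"""
AdaLN-Zero for the dual-stream block: a single linear layer maps the timestep conditioning to 12 modulation
tensors (shift/scale/gate for attention and MLP, for both the latent and glyph streams). The pre-attention
shift/scale is applied here; the gates and MLP modulation tensors are returned to the caller.
"""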
def __init__(self, embedding_dim: int, dim: int) -> None:
super().__init__()
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
self.norm_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
self.linear = nn.Linear(embedding_dim, 12 * dim, bias=True)
def forward(
self, hidden_states: torch.Tensor, glyph_hidden_states: torch.Tensor, temb: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
dtype = hidden_states.dtype
norm_hidden_states = self.norm(hidden_states).to(dtype=dtype)
norm_glyph_hidden_states = self.norm_context(glyph_hidden_states).to(dtype=dtype)
emb = self.linear(temb)
(
shift_msa,
c_shift_msa,
scale_msa,
c_scale_msa,
gate_msa,
c_gate_msa,
shift_mlp,
c_shift_mlp,
scale_mlp,
c_scale_mlp,
gate_mlp,
c_gate_mlp,
) = emb.chunk(12, dim=1)
hidden_states = norm_hidden_states * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
glyph_hidden_states = norm_glyph_hidden_states * (1 + c_scale_msa.unsqueeze(1)) + c_shift_msa.unsqueeze(1)
return (
hidden_states,
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
glyph_hidden_states,
c_gate_msa,
c_shift_mlp,
c_scale_mlp,
c_gate_mlp,
)
class GlmImageDecoderAttenProcessorState(Enum):
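"""
Controls how `GlmImageDecoderAttnProcessor` uses its per-layer KV cache: `ImageGen` leaves the cache untouched,
`ImageEditWriteKV` appends the current keys/values to the cache (used while encoding condition images),
`ImageEditReadKV` prepends the cached keys/values to the current ones, and `ImageEditDontReadKV` runs attention
without reading from or writing to the cache.
"""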
ImageGen = "ImageGen"
ImageEditWriteKV = "ImageEditWriteKV"
ImageEditReadKV = "ImageEditReadKV"
ImageEditDontReadKV = "ImageEditNoReadKV"
class GlmImageDecoderAttnProcessor:
"""
Processor for implementing scaled dot-product attention for the GlmImageDecoder model. It applies a rotary
embedding on query and key vectors, but does not include spatial normalization.
The processor supports passing an attention mask for text tokens. The attention mask should have shape (batch_size,
text_seq_length) where 1 indicates a non-padded token and 0 indicates a padded token.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"GlmImageDecoderAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
)
self.processor_state = GlmImageDecoderAttenProcessorState.ImageGen
self.k_cache = None
self.v_cache = None
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
dtype = encoder_hidden_states.dtype
batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
batch_size, image_seq_length, embed_dim = hidden_states.shape
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
# 1. QKV projections
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
# 2. QK normalization
if attn.norm_q is not None:
query = attn.norm_q(query).to(dtype=dtype)
if attn.norm_k is not None:
key = attn.norm_k(key).to(dtype=dtype)
# 3. Rotational positional embeddings applied to latent stream
if image_rotary_emb is not None:
from ..embeddings import apply_rotary_emb
query[:, :, text_seq_length:, :] = apply_rotary_emb(
query[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
)
key[:, :, text_seq_length:, :] = apply_rotary_emb(
key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
)
if self.processor_state == GlmImageDecoderAttenProcessorState.ImageEditWriteKV:
self.k_cache = key if self.k_cache is None else torch.cat([self.k_cache, key], dim=2)
self.v_cache = value if self.v_cache is None else torch.cat([self.v_cache, value], dim=2)
elif self.processor_state == GlmImageDecoderAttenProcessorState.ImageEditReadKV:
key = torch.cat([self.k_cache, key], dim=2) if self.k_cache is not None else key
value = torch.cat([self.v_cache, value], dim=2) if self.v_cache is not None else value
# 4. Attention
if attention_mask is not None:
text_attn_mask = attention_mask
assert text_attn_mask.dim() == 2, "the shape of text_attn_mask should be (batch_size, text_seq_length)"
text_attn_mask = text_attn_mask.float().to(query.device)
mix_attn_mask = torch.ones((batch_size, text_seq_length + image_seq_length), device=query.device)
mix_attn_mask[:, :text_seq_length] = text_attn_mask
mix_attn_mask = mix_attn_mask.unsqueeze(2)
attn_mask_matrix = mix_attn_mask @ mix_attn_mask.transpose(1, 2)
attention_mask = (attn_mask_matrix > 0).unsqueeze(1).to(query.dtype)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states = hidden_states.type_as(query)
# 5. Output projection
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states, hidden_states = hidden_states.split(
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
)
return hidden_states, encoder_hidden_states
@maybe_allow_in_graph
class GlmImageDecoderTransformerBlock(nn.Module):
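"""
A dual-stream transformer block that jointly attends over latent-image tokens and glyph tokens: AdaLN-Zero
modulation, joint self-attention with rotary embeddings applied to the image tokens only, and a shared
feed-forward applied to both streams.
"""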
def __init__(
self,
dim: int = 2560,
num_attention_heads: int = 64,
attention_head_dim: int = 40,
time_embed_dim: int = 512,
) -> None:
super().__init__()
# 1. Attention
self.norm1 = GlmImageDecoderAdaLayerNormZero(time_embed_dim, dim)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
out_dim=dim,
bias=True,
qk_norm="layer_norm",
elementwise_affine=False,
eps=1e-5,
processor=GlmImageDecoderAttnProcessor(),
)
# 2. Feedforward
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
def forward(
self,
hidden_states: torch.Tensor,
glyph_hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[
Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
] = None,
attention_mask: Optional[Dict[str, torch.Tensor]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Timestep conditioning
(
norm_hidden_states,
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
norm_glyph_hidden_states,
c_gate_msa,
c_shift_mlp,
c_scale_mlp,
c_gate_mlp,
) = self.norm1(hidden_states, glyph_hidden_states, temb)
# 2. Attention
if attention_kwargs is None:
attention_kwargs = {}
attn_hidden_states, attn_glyph_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_glyph_hidden_states,
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask,
**attention_kwargs,
)
hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1)
glyph_hidden_states = glyph_hidden_states + attn_glyph_hidden_states * c_gate_msa.unsqueeze(1)
# 3. Feedforward
norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
norm_glyph_hidden_states = self.norm2_context(glyph_hidden_states) * (
1 + c_scale_mlp.unsqueeze(1)
) + c_shift_mlp.unsqueeze(1)
ff_output = self.ff(norm_hidden_states)
ff_output_context = self.ff(norm_glyph_hidden_states)
hidden_states = hidden_states + ff_output * gate_mlp.unsqueeze(1)
glyph_hidden_states = glyph_hidden_states + ff_output_context * c_gate_mlp.unsqueeze(1)
return hidden_states, glyph_hidden_states
class GlmImageDecoderRotaryPosEmbed(nn.Module):
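"""
Computes 2D rotary position embeddings over the post-patch grid and returns `(cos, sin)` tensors of shape
`(num_patches, attention_head_dim)`.
"""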
def __init__(self, dim: int, patch_size: int, theta: float = 10000.0) -> None:
super().__init__()
self.dim = dim
self.patch_size = patch_size
self.theta = theta
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size, num_channels, height, width = hidden_states.shape
height, width = height // self.patch_size, width // self.patch_size
dim_h, dim_w = self.dim // 2, self.dim // 2
h_inv_freq = 1.0 / (
self.theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h)
)
w_inv_freq = 1.0 / (
self.theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w)
)
h_seq = torch.arange(height)
w_seq = torch.arange(width)
freqs_h = torch.outer(h_seq, h_inv_freq)
freqs_w = torch.outer(w_seq, w_inv_freq)
# Create position matrices for height and width
# [height, 1, dim//4] and [1, width, dim//4]
freqs_h = freqs_h.unsqueeze(1)
freqs_w = freqs_w.unsqueeze(0)
# Broadcast freqs_h and freqs_w to [height, width, dim//4]
freqs_h = freqs_h.expand(height, width, -1)
freqs_w = freqs_w.expand(height, width, -1)
# Concatenate along last dimension to get [height, width, dim//2]
freqs = torch.cat([freqs_h, freqs_w], dim=-1)
freqs = torch.cat([freqs, freqs], dim=-1) # [height, width, dim]
freqs = freqs.reshape(height * width, -1)
return (freqs.cos(), freqs.sin())
class GlmImageDecoderAdaLayerNormContinuous(nn.Module):
"""
GlmImageDecoder-only final AdaLN: LN(x) -> Linear(cond) -> chunk -> affine. Matches Megatron: **no activation**
before the Linear on conditioning embedding.
"""
def __init__(
self,
embedding_dim: int,
conditioning_embedding_dim: int,
elementwise_affine: bool = True,
eps: float = 1e-5,
bias: bool = True,
norm_type: str = "layer_norm",
):
super().__init__()
self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
if norm_type == "layer_norm":
self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
elif norm_type == "rms_norm":
self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
else:
raise ValueError(f"unknown norm_type {norm_type}")
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
# *** NO SiLU here ***
emb = self.linear(conditioning_embedding.to(x.dtype))
scale, shift = torch.chunk(emb, 2, dim=1)
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
return x
class GlmImageDecoderTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
r"""
Args:
patch_size (`int`, defaults to `2`):
The size of the patches to use in the patch embedding layer.
in_channels (`int`, defaults to `16`):
The number of channels in the input.
num_layers (`int`, defaults to `30`):
The number of layers of Transformer blocks to use.
attention_head_dim (`int`, defaults to `40`):
The number of channels in each head.
num_attention_heads (`int`, defaults to `64`):
The number of heads to use for multi-head attention.
out_channels (`int`, defaults to `16`):
The number of channels in the output.
text_embed_dim (`int`, defaults to `4096`):
Input dimension of text embeddings from the text encoder.
time_embed_dim (`int`, defaults to `512`):
Output dimension of timestep embeddings.
condition_dim (`int`, defaults to `256`):
The embedding dimension of the input SDXL-style resolution conditions (original_size, target_size,
crop_coords).
pos_embed_max_size (`int`, defaults to `128`):
The maximum resolution of the positional embeddings, from which slices of shape `H x W` are taken and added
to input patched latents, where `H` and `W` are the latent height and width respectively. A value of 128
means that the maximum supported height and width for image generation is `128 * vae_scale_factor *
patch_size => 128 * 8 * 2 => 2048`.
sample_size (`int`, defaults to `128`):
The base resolution of input latents. If height/width is not provided during generation, this value is used
to determine the resolution as `sample_size * vae_scale_factor => 128 * 8 => 1024`
"""
_supports_gradient_checkpointing = True
_no_split_modules = [
"GlmImageDecoderTransformerBlock",
"GlmImageDecoderImageProjector",
"GlmImageDecoderImageProjector",
]
_skip_layerwise_casting_patterns = ["patch_embed", "norm", "proj_out"]
@register_to_config
def __init__(
self,
patch_size: int = 2,
in_channels: int = 16,
out_channels: int = 16,
num_layers: int = 30,
attention_head_dim: int = 40,
num_attention_heads: int = 64,
text_embed_dim: int = 4096,
glyph_embed_dim: int = 1472,
time_embed_dim: int = 512,
condition_dim: int = 256,
pos_embed_max_size: int = 128,
sample_size: int = 128,
prior_vq_quantizer_codebook_size: int = 16384,
):
super().__init__()
# GlmImageDecoder uses 2 additional SDXL-like conditions - target_size, crop_coords
# Each of these are sincos embeddings of shape 2 * condition_dim
pooled_projection_dim = 2 * 2 * condition_dim
inner_dim = num_attention_heads * attention_head_dim
out_channels = out_channels
# 1. RoPE
self.rope = GlmImageDecoderRotaryPosEmbed(attention_head_dim, patch_size, theta=10000.0)
# 2. Patch & Text-timestep embedding
self.image_projector = GlmImageDecoderImageProjector(in_channels, inner_dim, patch_size)
# No text_projector for now; one may be added in the future
# self.text_projector = FeedForward(text_embed_dim, inner_dim, activation_fn="gelu")
self.glyph_projector = FeedForward(glyph_embed_dim, inner_dim, inner_dim=inner_dim, activation_fn="gelu")
self.prior_token_embedding = nn.Embedding(prior_vq_quantizer_codebook_size, inner_dim)
self.prior_projector = FeedForward(inner_dim, inner_dim, inner_dim=inner_dim, activation_fn="linear-silu")
self.time_condition_embed = GlmImageDecoderCombinedTimestepSizeEmbeddings(
embedding_dim=time_embed_dim,
condition_dim=condition_dim,
pooled_projection_dim=pooled_projection_dim,
timesteps_dim=time_embed_dim,
)
# 3. Transformer blocks
self.transformer_blocks = nn.ModuleList(
[
GlmImageDecoderTransformerBlock(inner_dim, num_attention_heads, attention_head_dim, time_embed_dim)
for _ in range(num_layers)
]
)
# 4. Output projection
self.norm_out = GlmImageDecoderAdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels, bias=True)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
glyph_hidden_states: torch.Tensor,
prior_token_id: torch.Tensor,
prior_token_drop: torch.Tensor,
timestep: torch.LongTensor,
original_size: torch.Tensor,
target_size: torch.Tensor,
crop_coords: torch.Tensor,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[
Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
] = None,
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
batch_size, num_channels, height, width = hidden_states.shape
# 1. RoPE
if image_rotary_emb is None:
image_rotary_emb = self.rope(hidden_states)
# 2. Patch & Timestep embeddings
p = self.config.patch_size
post_patch_height = height // p
post_patch_width = width // p
hidden_states = self.image_projector(hidden_states)
glyph_hidden_states = self.glyph_projector(glyph_hidden_states)
prior_embedding = self.prior_token_embedding(prior_token_id)
prior_embedding[prior_token_drop] *= 0.0
prior_hidden_states = self.prior_projector(prior_embedding)
hidden_states = hidden_states + prior_hidden_states
temb = self.time_condition_embed(timestep, target_size, crop_coords, hidden_states.dtype)
temb = F.silu(temb)
# 3. Transformer blocks
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states, glyph_hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
glyph_hidden_states,
temb,
image_rotary_emb,
attention_mask,
attention_kwargs,
)
else:
hidden_states, glyph_hidden_states = block(
hidden_states,
glyph_hidden_states,
temb,
image_rotary_emb,
attention_mask,
attention_kwargs,
)
# 4. Output norm & projection
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)
# 5. Unpatchify
hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p)
output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
def set_attention_processors_state(self, state: GlmImageDecoderAttenProcessorState):
for block in self.transformer_blocks:
block.attn1.processor.processor_state = state
def clear_attention_processors_cache(self):
for block in self.transformer_blocks:
block.attn1.processor.k_cache = None
block.attn1.processor.v_cache = None
def repeat_attention_processors_cache(self, repeats: int):
for block in self.transformer_blocks:
if block.attn1.processor.k_cache is None or block.attn1.processor.v_cache is None:
continue
block.attn1.processor.k_cache = torch.repeat_interleave(block.attn1.processor.k_cache, repeats, dim=2)
block.attn1.processor.v_cache = torch.repeat_interleave(block.attn1.processor.v_cache, repeats, dim=2)
if __name__ == "__main__":
def swap_scale_shift(weight, dim):
"""
Swap the scale and shift components in the weight tensor.
Args:
weight (torch.Tensor): The original weight tensor.
dim (int): The dimension along which to split.
Returns:
torch.Tensor: The modified weight tensor with scale and shift swapped.
"""
shift, scale = weight.chunk(2, dim=dim)
new_weight = torch.cat([scale, shift], dim=dim)
return new_weight
def convert_megatron_transformer_checkpoint_to_diffusers(
ckpt_path: str,
num_layers: int,
num_heads: int,
hidden_size: int,
):
"""
Convert a Megatron Transformer checkpoint to Diffusers format.
Args:
ckpt_path (str): Path to the Megatron Transformer checkpoint.
num_layers (int): Number of Transformer layers.
num_heads (int): Number of attention heads.
hidden_size (int): Hidden size of the Transformer.
Returns:
dict: The converted state dictionary compatible with Diffusers.
"""
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
mega = ckpt["model"]
used_keys = set()
def get_mega(key):
used_keys.add(key)
return mega[key]
new_state_dict = {}
# Patch Embedding
new_state_dict["image_projector.proj.weight"] = get_mega("encoder_expand_linear.weight").reshape(
hidden_size, 64
)
new_state_dict["image_projector.proj.bias"] = get_mega("encoder_expand_linear.bias")
new_state_dict["glyph_projector.net.0.proj.weight"] = get_mega("glyph_projector.linear_fc1.weight")
new_state_dict["glyph_projector.net.0.proj.bias"] = get_mega("glyph_projector.linear_fc1.bias")
new_state_dict["glyph_projector.net.2.weight"] = get_mega("glyph_projector.linear_fc2.weight")
new_state_dict["glyph_projector.net.2.bias"] = get_mega("glyph_projector.linear_fc2.bias")
new_state_dict["prior_token_embedding.weight"] = get_mega("xomni_token_id_embedding.weight")
new_state_dict["prior_projector.net.0.proj.weight"] = get_mega("prior_condition_embedding.0.weight")
new_state_dict["prior_projector.net.0.proj.bias"] = get_mega("prior_condition_embedding.0.bias")
new_state_dict["prior_projector.net.2.weight"] = get_mega("prior_condition_embedding.2.weight")
new_state_dict["prior_projector.net.2.bias"] = get_mega("prior_condition_embedding.2.bias")
# Time Condition Embedding
new_state_dict["time_condition_embed.timestep_embedder.linear_1.weight"] = get_mega(
"time_embedding.time_embed.0.weight"
)
new_state_dict["time_condition_embed.timestep_embedder.linear_1.bias"] = get_mega(
"time_embedding.time_embed.0.bias"
)
new_state_dict["time_condition_embed.timestep_embedder.linear_2.weight"] = get_mega(
"time_embedding.time_embed.2.weight"
)
new_state_dict["time_condition_embed.timestep_embedder.linear_2.bias"] = get_mega(
"time_embedding.time_embed.2.bias"
)
new_state_dict["time_condition_embed.condition_embedder.linear_1.weight"] = get_mega(
"label_embedding.label_embed.0.weight"
)
new_state_dict["time_condition_embed.condition_embedder.linear_1.bias"] = get_mega(
"label_embedding.label_embed.0.bias"
)
new_state_dict["time_condition_embed.condition_embedder.linear_2.weight"] = get_mega(
"label_embedding.label_embed.2.weight"
)
new_state_dict["time_condition_embed.condition_embedder.linear_2.bias"] = get_mega(
"label_embedding.label_embed.2.bias"
)
# Convert each Transformer layer
from tqdm import tqdm
for i in tqdm(range(num_layers), desc="Converting layers (Megatron->Diffusers)"):
block_prefix = f"transformer_blocks.{i}."
# AdaLayerNorm
new_state_dict[block_prefix + "norm1.linear.weight"] = get_mega(f"decoder.layers.{i}.adaln.weight")
new_state_dict[block_prefix + "norm1.linear.bias"] = get_mega(f"decoder.layers.{i}.adaln.bias")
qkv_weight = get_mega(f"decoder.layers.{i}.self_attention.linear_qkv.weight")
qkv_bias = get_mega(f"decoder.layers.{i}.self_attention.linear_qkv.bias")
# Reshape to match SAT logic
qkv_weight = qkv_weight.view(num_heads, 3, hidden_size // num_heads, hidden_size)
qkv_weight = qkv_weight.permute(1, 0, 2, 3).reshape(3 * hidden_size, hidden_size)
qkv_bias = qkv_bias.view(num_heads, 3, hidden_size // num_heads)
qkv_bias = qkv_bias.permute(1, 0, 2).reshape(3 * hidden_size)
# Assign to Diffusers keys
q, k, v = torch.chunk(qkv_weight, 3, dim=0)
qb, kb, vb = torch.chunk(qkv_bias, 3, dim=0)
new_state_dict[block_prefix + "attn1.to_q.weight"] = q
new_state_dict[block_prefix + "attn1.to_q.bias"] = qb
new_state_dict[block_prefix + "attn1.to_k.weight"] = k
new_state_dict[block_prefix + "attn1.to_k.bias"] = kb
new_state_dict[block_prefix + "attn1.to_v.weight"] = v
new_state_dict[block_prefix + "attn1.to_v.bias"] = vb
# Attention Output
new_state_dict[block_prefix + "attn1.to_out.0.weight"] = get_mega(
f"decoder.layers.{i}.self_attention.linear_proj.weight"
)
new_state_dict[block_prefix + "attn1.to_out.0.bias"] = get_mega(
f"decoder.layers.{i}.self_attention.linear_proj.bias"
)
# MLP
new_state_dict[block_prefix + "ff.net.0.proj.weight"] = get_mega(
f"decoder.layers.{i}.mlp.linear_fc1.weight"
)
new_state_dict[block_prefix + "ff.net.0.proj.bias"] = get_mega(f"decoder.layers.{i}.mlp.linear_fc1.bias")
new_state_dict[block_prefix + "ff.net.2.weight"] = get_mega(f"decoder.layers.{i}.mlp.linear_fc2.weight")
new_state_dict[block_prefix + "ff.net.2.bias"] = get_mega(f"decoder.layers.{i}.mlp.linear_fc2.bias")
# Final Layers
new_state_dict["norm_out.linear.weight"] = swap_scale_shift(get_mega("adaln_final.weight"), dim=0)
new_state_dict["norm_out.linear.bias"] = swap_scale_shift(get_mega("adaln_final.bias"), dim=0)
new_state_dict["proj_out.weight"] = get_mega("output_projector.weight")
new_state_dict["proj_out.bias"] = get_mega("output_projector.bias")
# Check for unused keys
all_keys = set(mega.keys())
unused_keys = all_keys - used_keys
if unused_keys:
print(f"\n[WARNING] The following {len(unused_keys)} keys in mega were NOT used:")
for key in sorted(unused_keys):
print(f" - {key}")
raise ValueError(
f"Found {len(unused_keys)} unused keys in Megatron checkpoint. Please update the conversion script to handle these keys."
)
else:
print(f"\n[INFO] All {len(all_keys)} keys in mega were successfully used.")
return new_state_dict
transformer = GlmImageDecoderTransformer2DModel(
patch_size=2,
in_channels=16,
num_layers=30,
attention_head_dim=128,
num_attention_heads=32,
out_channels=16,
text_embed_dim=4096,
time_embed_dim=512,
glyph_embed_dim=1472,
condition_dim=256,
pos_embed_max_size=128,
).to(torch.bfloat16)
converted_transformer_state_dict = convert_megatron_transformer_checkpoint_to_diffusers(
ckpt_path="/workspace/ckpt/tjy/Glm-train-dev/examples/cogview/ckpts/merge/1+6_0.5+0.5/iter_0000000/mp_rank_00/model_optim_rng.pt",
num_layers=30,
num_heads=32,
hidden_size=4096,
)
transformer.load_state_dict(converted_transformer_state_dict)
transformer.cuda()
latent = torch.load("/workspace/ckpt/tjy/glm-train-dev/examples/cogview/latent.pt").to(torch.bfloat16)
latent = rearrange(latent, "(b h w) (c p q) -> b c (h p) (w q)", b=8, h=72, w=54, p=2, q=2)
glyph_hidden_states = torch.load(
"/workspace/ckpt/tjy/glm-train-dev/examples/cogview/glyph_condition_embedding.pt"
).to(torch.bfloat16)
glyph_hidden_states = rearrange(glyph_hidden_states, "(b n) c -> b n c", b=8, n=2)
prior_token_id = torch.load("/workspace/ckpt/tjy/glm-train-dev/examples/cogview/xomni_token_id.pt")
prior_token_drop = torch.load("/workspace/ckpt/tjy/glm-train-dev/examples/cogview/xomni_drop.pt")
prior_token_id = rearrange(prior_token_id, "(b n) -> b n", b=8)
prior_token_drop = rearrange(prior_token_drop, "(b n)-> b n", b=8)
with torch.no_grad():
output = transformer(
hidden_states=latent,
glyph_hidden_states=glyph_hidden_states,
prior_token_id=prior_token_id,
prior_token_drop=prior_token_drop,
timestep=torch.tensor([999.0] * 8).cuda(),
original_size=torch.tensor([[144, 108]] * 8).cuda(),
target_size=torch.tensor([[144, 108]] * 8).cuda(),
crop_coords=torch.tensor([[0, 0]] * 8).cuda(),
)

View File

@@ -0,0 +1,47 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_additional_imports = {}
_import_structure = {"pipeline_output": ["GlmImageDecoderPipelineOutput"]}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_glm_image"] = ["GlmImageDecoderPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .pipeline_glm_image import GlmImageDecoderPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
for name, value in _additional_imports.items():
setattr(sys.modules[__name__], name, value)

View File

@@ -0,0 +1,786 @@
# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import re
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import PIL
import torch
from transformers import AutoTokenizer, T5EncoderModel
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import VaeImageProcessor
from ...loaders import CogView4LoraLoaderMixin
from ...models import AutoencoderKL, GlmImageDecoderTransformer2DModel
from ...models.transformers.transformer_glm_image import GlmImageDecoderAttenProcessorState
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from .pipeline_output import GlmImageDecoderPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from diffusers import GlmImageDecoderPipeline
>>> pipe = GlmImageDecoderPipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> prompt = "A photo of an astronaut riding a horse on mars"
>>> image = pipe(prompt).images[0]
>>> image.save("output.png")
```
"""
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
base_shift: float = 0.25,
max_shift: float = 0.75,
) -> float:
m = (image_seq_len / base_seq_len) ** 0.5
mu = m * max_shift + base_shift
return mu
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
accepts_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if timesteps is not None and sigmas is not None:
if not accepts_timesteps and not accepts_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep or sigma schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif timesteps is not None and sigmas is None:
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif timesteps is None and sigmas is not None:
if not accepts_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
class GlmImageDecoderPipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
r"""
Pipeline for text-to-image generation using GLM-Image.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
tokenizer (`AutoTokenizer`):
Tokenizer used to tokenize the glyph texts extracted from the prompt before passing them to `glyph_encoder`.
glyph_encoder ([`T5EncoderModel`]):
Frozen text encoder used to embed the glyph (quoted) text snippets found in the prompt.
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
transformer ([`GlmImageDecoderTransformer2DModel`]):
A conditioned `GlmImageDecoderTransformer2DModel` to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
"""
_optional_components = []
model_cpu_offload_seq = "glyph_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
tokenizer: AutoTokenizer,
glyph_encoder: T5EncoderModel,
vae: AutoencoderKL,
transformer: GlmImageDecoderTransformer2DModel,
scheduler: FlowMatchEulerDiscreteScheduler,
):
super().__init__()
self.register_modules(
tokenizer=tokenizer, glyph_encoder=glyph_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = (
self.transformer.config.sample_size
if hasattr(self, "transformer")
and self.transformer is not None
and hasattr(self.transformer.config, "sample_size")
else 128
)
def get_glyph_texts(
self,
prompt,
):
prompt = prompt[0] if isinstance(prompt, list) else prompt
ocr_texts = (
re.findall(r"'([^']*)'", prompt)
+ re.findall(r"“([^“”]*)”", prompt)
+ re.findall(r'"([^"]*)"', prompt)
+ re.findall(r"「([^「」]*)」", prompt)
)
return ocr_texts
def _get_glyph_embeds(
self,
prompt: Union[str, List[str]] = None,
max_sequence_length: int = 2048,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.glyph_encoder.dtype
glyph_texts = self.get_glyph_texts(prompt)
input_ids = self.tokenizer(
glyph_texts if len(glyph_texts) > 0 else [""],
max_length=max_sequence_length,
truncation=True,
).input_ids
input_ids = [
[self.tokenizer.pad_token_id] * ((len(input_ids) + 1) % 2) + input_ids_ for input_ids_ in input_ids
]
max_length = max(len(input_ids_) for input_ids_ in input_ids)
attention_mask = torch.tensor(
[[1] * len(input_ids_) + [0] * (max_length - len(input_ids_)) for input_ids_ in input_ids], device=device
)
input_ids = torch.tensor(
[input_ids_ + [self.tokenizer.pad_token_id] * (max_length - len(input_ids_)) for input_ids_ in input_ids],
device=device,
)
outputs = self.glyph_encoder(input_ids, attention_mask=attention_mask)
glyph_embeds = outputs.last_hidden_state[attention_mask.bool()].unsqueeze(0)
return glyph_embeds.to(device=device, dtype=dtype)
def encode_prompt(
self,
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
do_classifier_free_guidance: bool = True,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
max_sequence_length: int = 2048,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
Whether to use classifier free guidance or not.
num_images_per_prompt (`int`, *optional*, defaults to 1):
Number of images that should be generated per prompt.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
device: (`torch.device`, *optional*):
torch device
dtype: (`torch.dtype`, *optional*):
torch dtype
max_sequence_length (`int`, defaults to `2048`):
Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
"""
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds = self._get_glyph_embeds(prompt, max_sequence_length, device, dtype)
seq_len = prompt_embeds.size(1)
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds = self._get_glyph_embeds(negative_prompt, max_sequence_length, device, dtype)
seq_len = negative_prompt_embeds.size(1)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
return prompt_embeds, negative_prompt_embeds
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
if latents is not None:
return latents.to(device)
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
return latents
def check_inputs(
self,
prompt,
height,
width,
negative_prompt,
callback_on_step_end_tensor_inputs,
prompt_embeds=None,
negative_prompt_embeds=None,
):
if (
height is not None
and height % (self.vae_scale_factor * self.transformer.config.patch_size) != 0
) or (
width is not None
and width % (self.vae_scale_factor * self.transformer.config.patch_size) != 0
):
logger.warning(
f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.transformer.config.patch_size} but are {height} and {width}. Dimensions will be resized accordingly."
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
@property
def guidance_scale(self):
return self._guidance_scale
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def num_timesteps(self):
return self._num_timesteps
@property
def attention_kwargs(self):
return self._attention_kwargs
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prior_token_id: Optional[torch.LongTensor] = None,
prompt: Optional[Union[str, List[str]]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
condition_images: Optional[
Union[
torch.Tensor, PIL.Image.Image, np.ndarray, List[torch.Tensor], List[PIL.Image.Image], List[np.ndarray]
]
] = None,
condition_images_prior_token_id: Optional[Union[torch.LongTensor, List[torch.LongTensor]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 1.5,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
original_size: Optional[Tuple[int, int]] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
output_type: str = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 2048,
) -> Union[GlmImageDecoderPipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
Args:
condition_images (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`, *optional*):
Image or images to condition the generation on (e.g. for image editing). For both numpy array and pytorch
tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list of tensors, the expected
shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a list of arrays, the expected
shape should be `(B, H, W, C)` or `(H, W, C)`.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. If not provided, it is set to 1024.
width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. If not provided, it is set to 1024.
num_inference_steps (`int`, *optional*, defaults to `50`):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to `1.5`):
Guidance scale as defined in [Classifier-Free Diffusion
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
the text `prompt`, usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to `1`):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
original_size (`Tuple[int]`, *optional*):
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`GlmImageDecoderPipelineOutput`] instead
of a plain tuple.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int`, defaults to `2048`):
Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
Examples:
Returns:
[`GlmImageDecoderPipelineOutput`] or `tuple`:
[`GlmImageDecoderPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple,
the first element is a list with the generated images.
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
negative_prompt,
callback_on_step_end_tensor_inputs,
prompt_embeds,
negative_prompt_embeds,
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
# 2. Default call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if batch_size != 1:
    raise ValueError("GlmImageDecoderPipeline currently only supports a batch size of 1.")
device = self._execution_device
# 3. Encode input prompt
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
negative_prompt,
self.do_classifier_free_guidance,
num_images_per_prompt=num_images_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
max_sequence_length=max_sequence_length,
device=device,
dtype=self.dtype,
)
# 4. process images
if condition_images is not None and not isinstance(condition_images, list):
condition_images = [condition_images]
condition_images_prior_token_id = [condition_images_prior_token_id]
if condition_images is not None and len(condition_images) != len(condition_images_prior_token_id):
    raise ValueError(
        "`condition_images` and `condition_images_prior_token_id` must have the same length."
    )
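# Each condition image is resized so its height and width are multiples of
# `vae_scale_factor * patch_size`, keeping the encoded latent grid aligned with the transformer's
# patch embedding; when `height`/`width` are not supplied, the rounded condition-image size is used.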
if condition_images is not None:
preprocessed_condition_images = []
for img in condition_images:
image_height, image_width = img.size[::-1] if isinstance(img, PIL.Image.Image) else img.shape[:2]
multiple_of = self.vae_scale_factor * self.transformer.config.patch_size
image_height = (image_height // multiple_of) * multiple_of
image_width = (image_width // multiple_of) * multiple_of
img = self.image_processor.preprocess(img, height=image_height, width=image_width)
preprocessed_condition_images.append(img)
height = height or image_height
width = width or image_width
condition_images = preprocessed_condition_images
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
# 5. Prepare latents and (optional) condition_images kv cache
latent_channels = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size=batch_size * num_images_per_prompt,
num_channels_latents=latent_channels,
height=height,
width=width,
dtype=torch.float32,
device=device,
generator=generator,
latents=latents,
)
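# When condition images are given, each one is VAE-encoded, normalized with the VAE's
# latents_mean/latents_std, and run through the transformer once at timestep 0 while the attention
# processors are in the `ImageEditWriteKV` state. This populates a key/value cache that the denoising
# loop attends to later; `empty_glyph_hiddens` is a zero-length text tensor because this write pass
# needs no glyph conditioning.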
if condition_images is not None:
self.transformer.set_attention_processors_state(GlmImageDecoderAttenProcessorState.ImageEditWriteKV)
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.latent_channels, 1, 1)
.to(self.vae.device, self.vae.dtype)
)
latents_std = (
torch.tensor(self.vae.config.latents_std)
.view(1, self.vae.config.latent_channels, 1, 1)
.to(self.vae.device, self.vae.dtype)
)
empty_glyph_hiddens = torch.zeros_like(prompt_embeds)[:1, :0, ...]
for condition_image, condition_image_prior_token_id in zip(
condition_images, condition_images_prior_token_id
):
condition_image = condition_image.to(device=device, dtype=self.vae.dtype)
condition_latent = retrieve_latents(
self.vae.encode(condition_image), generator=generator, sample_mode="argmax"
)
condition_latent = (condition_latent - latents_mean) / latents_std
_ = self.transformer(
hidden_states=condition_latent,
glyph_hidden_states=empty_glyph_hiddens,
prior_token_id=condition_image_prior_token_id,
prior_token_drop=torch.full_like(condition_image_prior_token_id, False, dtype=torch.bool),
timestep=torch.zeros((1,), device=device),
original_size=torch.tensor([condition_image.shape[-2:]], device=device),
target_size=torch.tensor([condition_image.shape[-2:]], device=device),
crop_coords=torch.zeros((1, 2), device=device),
attention_kwargs=attention_kwargs,
)
# 6. Prepare additional timestep conditions
original_size = original_size or (height, width)
target_size = (height, width)
original_size = torch.tensor([original_size], dtype=prompt_embeds.dtype, device=device)
target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=device)
crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype, device=device)
original_size = original_size.repeat(batch_size * num_images_per_prompt, 1)
target_size = target_size.repeat(batch_size * num_images_per_prompt, 1)
crops_coords_top_left = crops_coords_top_left.repeat(batch_size * num_images_per_prompt, 1)
# Prepare timesteps
image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // (
self.transformer.config.patch_size**2
)
timesteps = (
np.linspace(self.scheduler.config.num_train_timesteps, 1.0, num_inference_steps + 1)[:-1]
if timesteps is None
else np.array(timesteps)
)
timesteps = timesteps.astype(np.int64).astype(np.float32)
sigmas = timesteps / self.scheduler.config.num_train_timesteps if sigmas is None else sigmas
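# The flow-matching shift `mu` is derived from the latent sequence length, so larger target
# resolutions receive a larger timestep shift before `retrieve_timesteps` builds the final schedule.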
mu = calculate_shift(
image_seq_len,
self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("base_shift", 0.25),
self.scheduler.config.get("max_shift", 0.75),
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu
)
self._num_timesteps = len(timesteps)
# 7. Denoising loop
transformer_dtype = self.transformer.dtype
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
prior_token_drop_cond = torch.full_like(prior_token_id, False, dtype=torch.bool)
prior_token_drop_uncond = torch.full_like(prior_token_id, True, dtype=torch.bool)
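# Classifier-free guidance operates on two axes: the conditional pass uses the prompt embeddings and
# keeps the prior tokens (`prior_token_drop_cond` is all False), while the unconditional pass swaps in
# the negative prompt embeddings, drops the prior tokens, and skips reading the condition-image KV cache.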
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
latent_model_input = latents.to(transformer_dtype)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]) - 1
if condition_images is not None:
self.transformer.set_attention_processors_state(GlmImageDecoderAttenProcessorState.ImageEditReadKV)
noise_pred_cond = self.transformer(
hidden_states=latent_model_input,
glyph_hidden_states=prompt_embeds,
prior_token_id=prior_token_id,
prior_token_drop=prior_token_drop_cond,
timestep=timestep,
original_size=original_size,
target_size=target_size,
crop_coords=crops_coords_top_left,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0].float()
# perform guidance
if self.do_classifier_free_guidance:
if condition_images is not None:
self.transformer.set_attention_processors_state(
GlmImageDecoderAttenProcessorState.ImageEditDontReadKV
)
noise_pred_uncond = self.transformer(
hidden_states=latent_model_input,
glyph_hidden_states=negative_prompt_embeds,
prior_token_id=prior_token_id,
prior_token_drop=prior_token_drop_uncond,
timestep=timestep,
original_size=original_size,
target_size=target_size,
crop_coords=crops_coords_top_left,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0].float()
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
else:
noise_pred = noise_pred_cond
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
# call the callback, if provided
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, self.scheduler.sigmas[i], callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
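# Restore the attention processors to plain image generation and drop any cached keys/values from the
# condition-image write pass so the pipeline can be reused with fresh inputs.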
self.transformer.set_attention_processors_state(GlmImageDecoderAttenProcessorState.ImageGen)
self.transformer.clear_attention_processors_cache()
if not output_type == "latent":
latents = latents.to(self.vae.dtype)
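# Undo the VAE latent normalization (z_norm = (z - mean) / std) before decoding back to pixel space.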
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.latent_channels, 1, 1)
.to(latents.device, latents.dtype)
)
latents_std = (
torch.tensor(self.vae.config.latents_std)
.view(1, self.vae.config.latent_channels, 1, 1)
.to(latents.device, latents.dtype)
)
latents = latents * latents_std + latents_mean
image = self.vae.decode(latents, return_dict=False, generator=generator)[0]
else:
image = latents
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return GlmImageDecoderPipelineOutput(images=image)

View File

@@ -0,0 +1,21 @@
from dataclasses import dataclass
from typing import List, Union
import numpy as np
import PIL.Image
from ...utils import BaseOutput
@dataclass
class GlmImageDecoderPipelineOutput(BaseOutput):
"""
Output class for the GLM-Image decoder pipeline.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or NumPy array representing the denoised images of the diffusion pipeline.
"""
images: Union[List[PIL.Image.Image], np.ndarray]