From e13fb7655235416ef1167acc37d6606b03cc4f34 Mon Sep 17 00:00:00 2001 From: zRzRzRzRzRzRzR <2448370773@qq.com> Date: Wed, 7 Jan 2026 16:55:02 +0800 Subject: [PATCH] rename --- docs/source/en/_toctree.yml | 2 +- .../en/api/models/glm_image_transformer2d.md | 8 +- docs/source/en/api/pipelines/glm_image.md | 8 +- src/diffusers/__init__.py | 8 +- src/diffusers/models/__init__.py | 4 +- src/diffusers/models/embeddings.py | 2 +- src/diffusers/models/transformers/__init__.py | 2 +- .../transformers/transformer_glm_image.py | 319 +++--------------- src/diffusers/pipelines/auto_pipeline.py | 4 +- src/diffusers/pipelines/glm_image/__init__.py | 6 +- .../pipelines/glm_image/pipeline_glm_image.py | 36 +- .../pipelines/glm_image/pipeline_output.py | 2 +- src/diffusers/utils/dummy_pt_objects.py | 2 +- 13 files changed, 93 insertions(+), 310 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 456b52d494..8e2b9d6c04 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -354,7 +354,7 @@ - local: api/models/flux_transformer title: FluxTransformer2DModel - local: api/models/glm_image_transformer2d - title: GlmImageDecoderTransformer2DModel + title: GlmImageTransformer2DModel - local: api/models/hidream_image_transformer title: HiDreamImageTransformer2DModel - local: api/models/hunyuan_transformer2d diff --git a/docs/source/en/api/models/glm_image_transformer2d.md b/docs/source/en/api/models/glm_image_transformer2d.md index d31557c9f7..8a8b074560 100644 --- a/docs/source/en/api/models/glm_image_transformer2d.md +++ b/docs/source/en/api/models/glm_image_transformer2d.md @@ -9,10 +9,10 @@ Unless required by applicable law or agreed to in writing, software distributed an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. --> -# GlmImageDecoderTransformer2DModel +# GlmImageTransformer2DModel -A Diffusion Transformer model for 2D data from [GlmImageDecoderTransformer2DModel]() +A Diffusion Transformer model for 2D data from [GlmImageTransformer2DModel]() -## GlmImageDecoderTransformer2DModel +## GlmImageTransformer2DModel -[[autodoc]] GlmImageDecoderTransformer2DModel +[[autodoc]] GlmImageTransformer2DModel diff --git a/docs/source/en/api/pipelines/glm_image.md b/docs/source/en/api/pipelines/glm_image.md index 24b5e14a1a..c3787cd77b 100644 --- a/docs/source/en/api/pipelines/glm_image.md +++ b/docs/source/en/api/pipelines/glm_image.md @@ -20,12 +20,12 @@ This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/zai-org). The original weights can be found under [hf.co/zai-org](https://huggingface.co/zai-org). 
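A minimal usage sketch of the renamed pipeline follows; it mirrors the EXAMPLE_DOC_STRING in `pipeline_glm_image.py` from this patch, and the checkpoint id is the placeholder used there rather than a confirmed GLM-Image repo id.

```python
import torch
from diffusers import GlmImagePipeline

# Placeholder checkpoint id carried over from EXAMPLE_DOC_STRING in this patch;
# the released GLM-Image weights may live under a different repo id.
pipe = GlmImagePipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A photo of an astronaut riding a horse on mars"
# The renamed GlmImagePipelineOutput exposes generated images via `.images`.
image = pipe(prompt).images[0]
image.save("glm_image_example.png")
```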
-## GlmImageDecoderPipeline +## GlmImagePipeline -[[autodoc]] GlmImageDecoderPipeline +[[autodoc]] GlmImagePipeline - all - __call__ -## GlmImageDecoderPipelineOutput +## GlmImagePipelineOutput -[[autodoc]] pipelines.cogview4.pipeline_output.GlmImageDecoderPipelineOutput +[[autodoc]] pipelines.cogview4.pipeline_output.GlmImagePipelineOutput diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 4b85f2662a..ceb52a7409 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -223,7 +223,7 @@ else: "FluxControlNetModel", "FluxMultiControlNetModel", "FluxTransformer2DModel", - "GlmImageDecoderTransformer2DModel", + "GlmImageTransformer2DModel", "HiDreamImageTransformer2DModel", "HunyuanDiT2DControlNetModel", "HunyuanDiT2DModel", @@ -488,7 +488,7 @@ else: "FluxKontextPipeline", "FluxPipeline", "FluxPriorReduxPipeline", - "GlmImageDecoderPipeline", + "GlmImagePipeline", "HiDreamImagePipeline", "HunyuanDiTControlNetPipeline", "HunyuanDiTPAGPipeline", @@ -971,7 +971,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: FluxControlNetModel, FluxMultiControlNetModel, FluxTransformer2DModel, - GlmImageDecoderTransformer2DModel, + GlmImageTransformer2DModel, HiDreamImageTransformer2DModel, HunyuanDiT2DControlNetModel, HunyuanDiT2DModel, @@ -1206,7 +1206,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: FluxKontextPipeline, FluxPipeline, FluxPriorReduxPipeline, - GlmImageDecoderPipeline, + GlmImagePipeline, HiDreamImagePipeline, HunyuanDiTControlNetPipeline, HunyuanDiTPAGPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 29bf6016de..3851c1d541 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -96,7 +96,7 @@ if is_torch_available(): _import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"] _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"] _import_structure["transformers.transformer_flux2"] = ["Flux2Transformer2DModel"] - _import_structure["transformers.transformer_glm_image"] = ["GlmImageDecoderTransformer2DModel"] + _import_structure["transformers.transformer_glm_image"] = ["GlmImageTransformer2DModel"] _import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"] _import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"] _import_structure["transformers.transformer_hunyuan_video15"] = ["HunyuanVideo15Transformer3DModel"] @@ -204,7 +204,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: EasyAnimateTransformer3DModel, Flux2Transformer2DModel, FluxTransformer2DModel, - GlmImageDecoderTransformer2DModel, + GlmImageTransformer2DModel, HiDreamImageTransformer2DModel, HunyuanDiT2DModel, HunyuanImageTransformer2DModel, diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 14fd9e2de9..1947e1e534 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1658,7 +1658,7 @@ class CogView3CombinedTimestepSizeEmbeddings(nn.Module): return conditioning -class GlmImageDecoderCombinedTimestepSizeEmbeddings(nn.Module): +class GlmImageCombinedTimestepSizeEmbeddings(nn.Module): def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256): super().__init__() diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index ea051624f2..5f389979b5 100755 --- a/src/diffusers/models/transformers/__init__.py +++ 
b/src/diffusers/models/transformers/__init__.py @@ -27,7 +27,7 @@ if is_torch_available(): from .transformer_easyanimate import EasyAnimateTransformer3DModel from .transformer_flux import FluxTransformer2DModel from .transformer_flux2 import Flux2Transformer2DModel - from .transformer_glm_image import GlmImageDecoderTransformer2DModel + from .transformer_glm_image import GlmImageTransformer2DModel from .transformer_hidream_image import HiDreamImageTransformer2DModel from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel from .transformer_hunyuan_video15 import HunyuanVideo15Transformer3DModel diff --git a/src/diffusers/models/transformers/transformer_glm_image.py b/src/diffusers/models/transformers/transformer_glm_image.py index 54a1c7cbdb..4c296b48aa 100644 --- a/src/diffusers/models/transformers/transformer_glm_image.py +++ b/src/diffusers/models/transformers/transformer_glm_image.py @@ -27,7 +27,7 @@ from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward from ..attention_processor import Attention from ..cache_utils import CacheMixin -from ..embeddings import GlmImageDecoderCombinedTimestepSizeEmbeddings +from ..embeddings import GlmImageCombinedTimestepSizeEmbeddings from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import LayerNorm, RMSNorm @@ -36,7 +36,7 @@ from ..normalization import LayerNorm, RMSNorm logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class GlmImageDecoderImageProjector(nn.Module): +class GlmImageImageProjector(nn.Module): def __init__( self, in_channels: int = 16, @@ -62,7 +62,7 @@ class GlmImageDecoderImageProjector(nn.Module): return hidden_states -class GlmImageDecoderAdaLayerNormZero(nn.Module): +class GlmImageAdaLayerNormZero(nn.Module): def __init__(self, embedding_dim: int, dim: int) -> None: super().__init__() @@ -71,11 +71,11 @@ class GlmImageDecoderAdaLayerNormZero(nn.Module): self.linear = nn.Linear(embedding_dim, 12 * dim, bias=True) def forward( - self, hidden_states: torch.Tensor, glyph_hidden_states: torch.Tensor, temb: torch.Tensor + self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: dtype = hidden_states.dtype norm_hidden_states = self.norm(hidden_states).to(dtype=dtype) - norm_glyph_hidden_states = self.norm_context(glyph_hidden_states).to(dtype=dtype) + norm_encoder_hidden_states = self.norm_context(encoder_hidden_states).to(dtype=dtype) emb = self.linear(temb) ( @@ -94,7 +94,7 @@ class GlmImageDecoderAdaLayerNormZero(nn.Module): ) = emb.chunk(12, dim=1) hidden_states = norm_hidden_states * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) - glyph_hidden_states = norm_glyph_hidden_states * (1 + c_scale_msa.unsqueeze(1)) + c_shift_msa.unsqueeze(1) + encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_msa.unsqueeze(1)) + c_shift_msa.unsqueeze(1) return ( hidden_states, @@ -102,7 +102,7 @@ class GlmImageDecoderAdaLayerNormZero(nn.Module): shift_mlp, scale_mlp, gate_mlp, - glyph_hidden_states, + encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, @@ -110,17 +110,17 @@ class GlmImageDecoderAdaLayerNormZero(nn.Module): ) -class GlmImageDecoderAttenProcessorState(Enum): +class GlmImageAttenProcessorState(Enum): ImageGen = "ImageGen" ImageEditWriteKV = "ImageEditWriteKV" ImageEditReadKV = "ImageEditReadKV" ImageEditDontReadKV = "ImageEditNoReadKV" -class GlmImageDecoderAttnProcessor: +class GlmImageAttnProcessor: 
""" - Processor for implementing scaled dot-product attention for the GlmImageDecoder model. It applies a rotary - embedding on query and key vectors, but does not include spatial normalization. + Processor for implementing scaled dot-product attention for the GlmImage model. It applies a rotary embedding on + query and key vectors, but does not include spatial normalization. The processor supports passing an attention mask for text tokens. The attention mask should have shape (batch_size, text_seq_length) where 1 indicates a non-padded token and 0 indicates a padded token. @@ -128,10 +128,8 @@ class GlmImageDecoderAttnProcessor: def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "GlmImageDecoderAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0." - ) - self.processor_state = GlmImageDecoderAttenProcessorState.ImageGen + raise ImportError("GlmImageAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.") + self.processor_state = GlmImageAttenProcessorState.ImageGen self.k_cache = None self.v_cache = None @@ -175,10 +173,10 @@ class GlmImageDecoderAttnProcessor: key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2 ) - if self.processor_state == GlmImageDecoderAttenProcessorState.ImageEditWriteKV: + if self.processor_state == GlmImageAttenProcessorState.ImageEditWriteKV: self.k_cache = key if self.k_cache is None else torch.cat([self.k_cache, key], dim=2) self.v_cache = value if self.v_cache is None else torch.cat([self.v_cache, value], dim=2) - elif self.processor_state == GlmImageDecoderAttenProcessorState.ImageEditReadKV: + elif self.processor_state == GlmImageAttenProcessorState.ImageEditReadKV: key = torch.cat([self.k_cache, key], dim=2) if self.k_cache is not None else key value = torch.cat([self.v_cache, value], dim=2) if self.v_cache is not None else value @@ -210,7 +208,7 @@ class GlmImageDecoderAttnProcessor: @maybe_allow_in_graph -class GlmImageDecoderTransformerBlock(nn.Module): +class GlmImageTransformerBlock(nn.Module): def __init__( self, dim: int = 2560, @@ -221,7 +219,7 @@ class GlmImageDecoderTransformerBlock(nn.Module): super().__init__() # 1. Attention - self.norm1 = GlmImageDecoderAdaLayerNormZero(time_embed_dim, dim) + self.norm1 = GlmImageAdaLayerNormZero(time_embed_dim, dim) self.attn1 = Attention( query_dim=dim, heads=num_attention_heads, @@ -231,7 +229,7 @@ class GlmImageDecoderTransformerBlock(nn.Module): qk_norm="layer_norm", elementwise_affine=False, eps=1e-5, - processor=GlmImageDecoderAttnProcessor(), + processor=GlmImageAttnProcessor(), ) # 2. Feedforward @@ -242,7 +240,7 @@ class GlmImageDecoderTransformerBlock(nn.Module): def forward( self, hidden_states: torch.Tensor, - glyph_hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, image_rotary_emb: Optional[ Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]] @@ -257,42 +255,42 @@ class GlmImageDecoderTransformerBlock(nn.Module): shift_mlp, scale_mlp, gate_mlp, - norm_glyph_hidden_states, + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp, - ) = self.norm1(hidden_states, glyph_hidden_states, temb) + ) = self.norm1(hidden_states, encoder_hidden_states, temb) # 2. 
Attention if attention_kwargs is None: attention_kwargs = {} - attn_hidden_states, attn_glyph_hidden_states = self.attn1( + attn_hidden_states, attn_encoder_hidden_states = self.attn1( hidden_states=norm_hidden_states, - encoder_hidden_states=norm_glyph_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, image_rotary_emb=image_rotary_emb, attention_mask=attention_mask, **attention_kwargs, ) hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1) - glyph_hidden_states = glyph_hidden_states + attn_glyph_hidden_states * c_gate_msa.unsqueeze(1) + encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1) # 3. Feedforward norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) - norm_glyph_hidden_states = self.norm2_context(glyph_hidden_states) * ( + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) * ( 1 + c_scale_mlp.unsqueeze(1) ) + c_shift_mlp.unsqueeze(1) ff_output = self.ff(norm_hidden_states) - ff_output_context = self.ff(norm_glyph_hidden_states) + ff_output_context = self.ff(norm_encoder_hidden_states) hidden_states = hidden_states + ff_output * gate_mlp.unsqueeze(1) - glyph_hidden_states = glyph_hidden_states + ff_output_context * c_gate_mlp.unsqueeze(1) + encoder_hidden_states = encoder_hidden_states + ff_output_context * c_gate_mlp.unsqueeze(1) - return hidden_states, glyph_hidden_states + return hidden_states, encoder_hidden_states -class GlmImageDecoderRotaryPosEmbed(nn.Module): +class GlmImageRotaryPosEmbed(nn.Module): def __init__(self, dim: int, patch_size: int, theta: float = 10000.0) -> None: super().__init__() @@ -331,10 +329,10 @@ class GlmImageDecoderRotaryPosEmbed(nn.Module): return (freqs.cos(), freqs.sin()) -class GlmImageDecoderAdaLayerNormContinuous(nn.Module): +class GlmImageAdaLayerNormContinuous(nn.Module): """ - GlmImageDecoder-only final AdaLN: LN(x) -> Linear(cond) -> chunk -> affine. Matches Megatron: **no activation** - before the Linear on conditioning embedding. + GlmImage-only final AdaLN: LN(x) -> Linear(cond) -> chunk -> affine. Matches Megatron: **no activation** before the + Linear on conditioning embedding. 
""" def __init__( @@ -363,7 +361,7 @@ class GlmImageDecoderAdaLayerNormContinuous(nn.Module): return x -class GlmImageDecoderTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin): +class GlmImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin): r""" Args: patch_size (`int`, defaults to `2`): @@ -397,9 +395,9 @@ class GlmImageDecoderTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixi _supports_gradient_checkpointing = True _no_split_modules = [ - "GlmImageDecoderTransformerBlock", - "GlmImageDecoderImageProjector", - "GlmImageDecoderImageProjector", + "GlmImageTransformerBlock", + "GlmImageImageProjector", + "GlmImageImageProjector", ] _skip_layerwise_casting_patterns = ["patch_embed", "norm", "proj_out"] @@ -412,35 +410,30 @@ class GlmImageDecoderTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixi num_layers: int = 30, attention_head_dim: int = 40, num_attention_heads: int = 64, - text_embed_dim: int = 4096, - glyph_embed_dim: int = 1472, + text_embed_dim: int = 1472, time_embed_dim: int = 512, condition_dim: int = 256, - pos_embed_max_size: int = 128, - sample_size: int = 128, prior_vq_quantizer_codebook_size: int = 16384, ): super().__init__() - # GlmImageDecoder uses 2 additional SDXL-like conditions - target_size, crop_coords + # GlmImage uses 2 additional SDXL-like conditions - target_size, crop_coords # Each of these are sincos embeddings of shape 2 * condition_dim pooled_projection_dim = 2 * 2 * condition_dim inner_dim = num_attention_heads * attention_head_dim out_channels = out_channels # 1. RoPE - self.rope = GlmImageDecoderRotaryPosEmbed(attention_head_dim, patch_size, theta=10000.0) + self.rope = GlmImageRotaryPosEmbed(attention_head_dim, patch_size, theta=10000.0) # 2. Patch & Text-timestep embedding - self.image_projector = GlmImageDecoderImageProjector(in_channels, inner_dim, patch_size) - # 这次没有,未来可能有text_projector - # self.text_projector = FeedForward(text_embed_dim, inner_dim, activation_fn="gelu") - self.glyph_projector = FeedForward(glyph_embed_dim, inner_dim, inner_dim=inner_dim, activation_fn="gelu") + self.image_projector = GlmImageImageProjector(in_channels, inner_dim, patch_size) + self.text_projector = FeedForward(text_embed_dim, inner_dim, inner_dim=inner_dim, activation_fn="gelu") self.prior_token_embedding = nn.Embedding(prior_vq_quantizer_codebook_size, inner_dim) self.prior_projector = FeedForward(inner_dim, inner_dim, inner_dim=inner_dim, activation_fn="linear-silu") - self.time_condition_embed = GlmImageDecoderCombinedTimestepSizeEmbeddings( + self.time_condition_embed = GlmImageCombinedTimestepSizeEmbeddings( embedding_dim=time_embed_dim, condition_dim=condition_dim, pooled_projection_dim=pooled_projection_dim, @@ -450,13 +443,13 @@ class GlmImageDecoderTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixi # 3. Transformer blocks self.transformer_blocks = nn.ModuleList( [ - GlmImageDecoderTransformerBlock(inner_dim, num_attention_heads, attention_head_dim, time_embed_dim) + GlmImageTransformerBlock(inner_dim, num_attention_heads, attention_head_dim, time_embed_dim) for _ in range(num_layers) ] ) # 4. 
Output projection - self.norm_out = GlmImageDecoderAdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False) + self.norm_out = GlmImageAdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False) self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels, bias=True) self.gradient_checkpointing = False @@ -464,11 +457,10 @@ class GlmImageDecoderTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixi def forward( self, hidden_states: torch.Tensor, - glyph_hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, prior_token_id: torch.Tensor, prior_token_drop: torch.Tensor, timestep: torch.LongTensor, - original_size: torch.Tensor, target_size: torch.Tensor, crop_coords: torch.Tensor, attention_kwargs: Optional[Dict[str, Any]] = None, @@ -490,7 +482,7 @@ class GlmImageDecoderTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixi post_patch_width = width // p hidden_states = self.image_projector(hidden_states) - glyph_hidden_states = self.glyph_projector(glyph_hidden_states) + encoder_hidden_states = self.text_projector(encoder_hidden_states) prior_embedding = self.prior_token_embedding(prior_token_id) prior_embedding[prior_token_drop] *= 0.0 prior_hidden_states = self.prior_projector(prior_embedding) @@ -503,19 +495,19 @@ class GlmImageDecoderTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixi # 3. Transformer blocks for block in self.transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states, glyph_hidden_states = self._gradient_checkpointing_func( + hidden_states, encoder_hidden_states = self._gradient_checkpointing_func( block, hidden_states, - glyph_hidden_states, + encoder_hidden_states, temb, image_rotary_emb, attention_mask, attention_kwargs, ) else: - hidden_states, glyph_hidden_states = block( + hidden_states, encoder_hidden_states = block( hidden_states, - glyph_hidden_states, + encoder_hidden_states, temb, image_rotary_emb, attention_mask, @@ -534,7 +526,7 @@ class GlmImageDecoderTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixi return (output,) return Transformer2DModelOutput(sample=output) - def set_attention_processors_state(self, state: GlmImageDecoderAttenProcessorState): + def set_attention_processors_state(self, state: GlmImageAttenProcessorState): for block in self.transformer_blocks: block.attn1.processor.processor_state = state @@ -542,212 +534,3 @@ class GlmImageDecoderTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixi for block in self.transformer_blocks: block.attn1.processor.k_cache = None block.attn1.processor.v_cache = None - - def repeat_attention_processors_cache(self, repeats: int): - for block in self.transformer_blocks: - if block.attn1.processor.k_cache is None or block.attn1.processor.v_cache is None: - continue - block.attn1.processor.k_cache = torch.repeat_interleave(block.attn1.processor.k_cache, repeats, dim=2) - block.attn1.processor.v_cache = torch.repeat_interleave(block.attn1.processor.v_cache, repeats, dim=2) - - -if __name__ == "__main__": - - def swap_scale_shift(weight, dim): - """ - Swap the scale and shift components in the weight tensor. - - Args: - weight (torch.Tensor): The original weight tensor. - dim (int): The dimension along which to split. - - Returns: - torch.Tensor: The modified weight tensor with scale and shift swapped. 
- """ - shift, scale = weight.chunk(2, dim=dim) - new_weight = torch.cat([scale, shift], dim=dim) - return new_weight - - def convert_megatron_transformer_checkpoint_to_diffusers( - ckpt_path: str, - num_layers: int, - num_heads: int, - hidden_size: int, - ): - """ - Convert a Megatron Transformer checkpoint to Diffusers format. - - Args: - ckpt_path (str): Path to the Megatron Transformer checkpoint. - num_layers (int): Number of Transformer layers. - num_heads (int): Number of attention heads. - hidden_size (int): Hidden size of the Transformer. - - Returns: - dict: The converted state dictionary compatible with Diffusers. - """ - ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False) - mega = ckpt["model"] - used_keys = set() - - def get_mega(key): - used_keys.add(key) - return mega[key] - - new_state_dict = {} - - # Patch Embedding - new_state_dict["image_projector.proj.weight"] = get_mega("encoder_expand_linear.weight").reshape( - hidden_size, 64 - ) - new_state_dict["image_projector.proj.bias"] = get_mega("encoder_expand_linear.bias") - - new_state_dict["glyph_projector.net.0.proj.weight"] = get_mega("glyph_projector.linear_fc1.weight") - new_state_dict["glyph_projector.net.0.proj.bias"] = get_mega("glyph_projector.linear_fc1.bias") - new_state_dict["glyph_projector.net.2.weight"] = get_mega("glyph_projector.linear_fc2.weight") - new_state_dict["glyph_projector.net.2.bias"] = get_mega("glyph_projector.linear_fc2.bias") - - new_state_dict["prior_token_embedding.weight"] = get_mega("xomni_token_id_embedding.weight") - new_state_dict["prior_projector.net.0.proj.weight"] = get_mega("prior_condition_embedding.0.weight") - new_state_dict["prior_projector.net.0.proj.bias"] = get_mega("prior_condition_embedding.0.bias") - new_state_dict["prior_projector.net.2.weight"] = get_mega("prior_condition_embedding.2.weight") - new_state_dict["prior_projector.net.2.bias"] = get_mega("prior_condition_embedding.2.bias") - - # Time Condition Embedding - new_state_dict["time_condition_embed.timestep_embedder.linear_1.weight"] = get_mega( - "time_embedding.time_embed.0.weight" - ) - new_state_dict["time_condition_embed.timestep_embedder.linear_1.bias"] = get_mega( - "time_embedding.time_embed.0.bias" - ) - new_state_dict["time_condition_embed.timestep_embedder.linear_2.weight"] = get_mega( - "time_embedding.time_embed.2.weight" - ) - new_state_dict["time_condition_embed.timestep_embedder.linear_2.bias"] = get_mega( - "time_embedding.time_embed.2.bias" - ) - - new_state_dict["time_condition_embed.condition_embedder.linear_1.weight"] = get_mega( - "label_embedding.label_embed.0.weight" - ) - new_state_dict["time_condition_embed.condition_embedder.linear_1.bias"] = get_mega( - "label_embedding.label_embed.0.bias" - ) - new_state_dict["time_condition_embed.condition_embedder.linear_2.weight"] = get_mega( - "label_embedding.label_embed.2.weight" - ) - new_state_dict["time_condition_embed.condition_embedder.linear_2.bias"] = get_mega( - "label_embedding.label_embed.2.bias" - ) - - # Convert each Transformer layer - from tqdm import tqdm - - for i in tqdm(range(num_layers), desc="Converting layers (Megatron->Diffusers)"): - block_prefix = f"transformer_blocks.{i}." 
- - # AdaLayerNorm - new_state_dict[block_prefix + "norm1.linear.weight"] = get_mega(f"decoder.layers.{i}.adaln.weight") - new_state_dict[block_prefix + "norm1.linear.bias"] = get_mega(f"decoder.layers.{i}.adaln.bias") - qkv_weight = get_mega(f"decoder.layers.{i}.self_attention.linear_qkv.weight") - qkv_bias = get_mega(f"decoder.layers.{i}.self_attention.linear_qkv.bias") - - # Reshape to match SAT logic - qkv_weight = qkv_weight.view(num_heads, 3, hidden_size // num_heads, hidden_size) - qkv_weight = qkv_weight.permute(1, 0, 2, 3).reshape(3 * hidden_size, hidden_size) - - qkv_bias = qkv_bias.view(num_heads, 3, hidden_size // num_heads) - qkv_bias = qkv_bias.permute(1, 0, 2).reshape(3 * hidden_size) - - # Assign to Diffusers keys - q, k, v = torch.chunk(qkv_weight, 3, dim=0) - qb, kb, vb = torch.chunk(qkv_bias, 3, dim=0) - - new_state_dict[block_prefix + "attn1.to_q.weight"] = q - new_state_dict[block_prefix + "attn1.to_q.bias"] = qb - new_state_dict[block_prefix + "attn1.to_k.weight"] = k - new_state_dict[block_prefix + "attn1.to_k.bias"] = kb - new_state_dict[block_prefix + "attn1.to_v.weight"] = v - new_state_dict[block_prefix + "attn1.to_v.bias"] = vb - - # Attention Output - new_state_dict[block_prefix + "attn1.to_out.0.weight"] = get_mega( - f"decoder.layers.{i}.self_attention.linear_proj.weight" - ) - new_state_dict[block_prefix + "attn1.to_out.0.bias"] = get_mega( - f"decoder.layers.{i}.self_attention.linear_proj.bias" - ) - - # MLP - new_state_dict[block_prefix + "ff.net.0.proj.weight"] = get_mega( - f"decoder.layers.{i}.mlp.linear_fc1.weight" - ) - new_state_dict[block_prefix + "ff.net.0.proj.bias"] = get_mega(f"decoder.layers.{i}.mlp.linear_fc1.bias") - new_state_dict[block_prefix + "ff.net.2.weight"] = get_mega(f"decoder.layers.{i}.mlp.linear_fc2.weight") - new_state_dict[block_prefix + "ff.net.2.bias"] = get_mega(f"decoder.layers.{i}.mlp.linear_fc2.bias") - - # Final Layers - new_state_dict["norm_out.linear.weight"] = swap_scale_shift(get_mega("adaln_final.weight"), dim=0) - new_state_dict["norm_out.linear.bias"] = swap_scale_shift(get_mega("adaln_final.bias"), dim=0) - new_state_dict["proj_out.weight"] = get_mega("output_projector.weight") - new_state_dict["proj_out.bias"] = get_mega("output_projector.bias") - - # Check for unused keys - all_keys = set(mega.keys()) - unused_keys = all_keys - used_keys - if unused_keys: - print(f"\n[WARNING] The following {len(unused_keys)} keys in mega were NOT used:") - for key in sorted(unused_keys): - print(f" - {key}") - raise ValueError( - f"Found {len(unused_keys)} unused keys in Megatron checkpoint. Please update the conversion script to handle these keys." 
- ) - else: - print(f"\n[INFO] All {len(all_keys)} keys in mega were successfully used.") - - return new_state_dict - - transformer = GlmImageDecoderTransformer2DModel( - patch_size=2, - in_channels=16, - num_layers=30, - attention_head_dim=128, - num_attention_heads=32, - out_channels=16, - text_embed_dim=4096, - time_embed_dim=512, - glyph_embed_dim=1472, - condition_dim=256, - pos_embed_max_size=128, - ).to(torch.bfloat16) - converted_transformer_state_dict = convert_megatron_transformer_checkpoint_to_diffusers( - ckpt_path="/workspace/ckpt/tjy/Glm-train-dev/examples/cogview/ckpts/merge/1+6_0.5+0.5/iter_0000000/mp_rank_00/model_optim_rng.pt", - num_layers=30, - num_heads=32, - hidden_size=4096, - ) - transformer.load_state_dict(converted_transformer_state_dict) - transformer.cuda() - - latent = torch.load("/workspace/ckpt/tjy/glm-train-dev/examples/cogview/latent.pt").to(torch.bfloat16) - latent = rearrange(latent, "(b h w) (c p q) -> b c (h p) (w q)", b=8, h=72, w=54, p=2, q=2) - glyph_hidden_states = torch.load( - "/workspace/ckpt/tjy/glm-train-dev/examples/cogview/glyph_condition_embedding.pt" - ).to(torch.bfloat16) - glyph_hidden_states = rearrange(glyph_hidden_states, "(b n) c -> b n c", b=8, n=2) - prior_token_id = torch.load("/workspace/ckpt/tjy/glm-train-dev/examples/cogview/xomni_token_id.pt") - prior_token_drop = torch.load("/workspace/ckpt/tjy/glm-train-dev/examples/cogview/xomni_drop.pt") - prior_token_id = rearrange(prior_token_id, "(b n) -> b n", b=8) - prior_token_drop = rearrange(prior_token_drop, "(b n)-> b n", b=8) - - with torch.no_grad(): - output = transformer( - hidden_states=latent, - glyph_hidden_states=glyph_hidden_states, - prior_token_id=prior_token_id, - prior_token_drop=prior_token_drop, - timestep=torch.tensor([999.0] * 8).cuda(), - original_size=torch.tensor([[144, 108]] * 8).cuda(), - target_size=torch.tensor([[144, 108]] * 8).cuda(), - crop_coords=torch.tensor([[0, 0]] * 8).cuda(), - ) diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 28b882dfaf..6f583385de 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -52,7 +52,7 @@ from .flux import ( FluxKontextPipeline, FluxPipeline, ) -from .glm_image import GlmImageDecoderPipeline +from .glm_image import GlmImagePipeline from .hunyuandit import HunyuanDiTPipeline from .kandinsky import ( KandinskyCombinedPipeline, @@ -168,7 +168,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict( ("chroma", ChromaPipeline), ("cogview3", CogView3PlusPipeline), ("cogview4", CogView4Pipeline), - ("glm_image", GlmImageDecoderPipeline), + ("glm_image", GlmImagePipeline), ("cogview4-control", CogView4ControlPipeline), ("qwenimage", QwenImagePipeline), ("qwenimage-controlnet", QwenImageControlNetPipeline), diff --git a/src/diffusers/pipelines/glm_image/__init__.py b/src/diffusers/pipelines/glm_image/__init__.py index 24c20c4ef5..9df31b0b17 100644 --- a/src/diffusers/pipelines/glm_image/__init__.py +++ b/src/diffusers/pipelines/glm_image/__init__.py @@ -12,7 +12,7 @@ from ...utils import ( _dummy_objects = {} _additional_imports = {} -_import_structure = {"pipeline_output": ["GlmImageDecoderPipelineOutput"]} +_import_structure = {"pipeline_output": ["GlmImagePipelineOutput"]} try: if not (is_transformers_available() and is_torch_available()): @@ -22,7 +22,7 @@ except OptionalDependencyNotAvailable: _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: - _import_structure["pipeline_glm_image"] 
= ["GlmImageDecoderPipeline"] + _import_structure["pipeline_glm_image"] = ["GlmImagePipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): @@ -30,7 +30,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: - from .pipeline_glm_image import GlmImageDecoderPipeline + from .pipeline_glm_image import GlmImagePipeline else: import sys diff --git a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py index 825952be4b..03ecb868ea 100644 --- a/src/diffusers/pipelines/glm_image/pipeline_glm_image.py +++ b/src/diffusers/pipelines/glm_image/pipeline_glm_image.py @@ -25,13 +25,13 @@ from transformers import AutoTokenizer, T5EncoderModel from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import VaeImageProcessor from ...loaders import CogView4LoraLoaderMixin -from ...models import AutoencoderKL, GlmImageDecoderTransformer2DModel -from ...models.transformers.transformer_glm_image import GlmImageDecoderAttenProcessorState +from ...models import AutoencoderKL, GlmImageTransformer2DModel +from ...models.transformers.transformer_glm_image import GlmImageAttenProcessorState from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor -from .pipeline_output import GlmImageDecoderPipelineOutput +from .pipeline_output import GlmImagePipelineOutput if is_torch_xla_available(): @@ -47,9 +47,9 @@ EXAMPLE_DOC_STRING = """ Examples: ```python >>> import torch - >>> from diffusers import GlmImageDecoderPipeline + >>> from diffusers import GlmImagePipeline - >>> pipe = GlmImageDecoderPipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16) + >>> pipe = GlmImagePipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16) >>> pipe.to("cuda") >>> prompt = "A photo of an astronaut riding a horse on mars" @@ -151,7 +151,7 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -class GlmImageDecoderPipeline(DiffusionPipeline, CogView4LoraLoaderMixin): +class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin): r""" Pipeline for text-to-image generation using CogView4. @@ -162,7 +162,7 @@ class GlmImageDecoderPipeline(DiffusionPipeline, CogView4LoraLoaderMixin): vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`GlmModel`]): - Frozen text-encoder. CogView4 uses [Glm-4-9b-hf](https://huggingface.co/THUDM/Glm-4-9b-hf). + Frozen text-encoder. CogView4 uses [GLM-Image](https://huggingface.co/zai-org/GLM-Image). tokenizer (`PreTrainedTokenizer`): Tokenizer of class [PreTrainedTokenizer](https://huggingface.co/docs/transformers/main/en/main_classes/tokenizer#transformers.PreTrainedTokenizer). 
@@ -179,15 +179,15 @@ class GlmImageDecoderPipeline(DiffusionPipeline, CogView4LoraLoaderMixin): def __init__( self, tokenizer: AutoTokenizer, - glyph_encoder: T5EncoderModel, + text_encoder: T5EncoderModel, vae: AutoencoderKL, - transformer: GlmImageDecoderTransformer2DModel, + transformer: GlmImageTransformer2DModel, scheduler: FlowMatchEulerDiscreteScheduler, ): super().__init__() self.register_modules( - tokenizer=tokenizer, glyph_encoder=glyph_encoder, vae=vae, transformer=transformer, scheduler=scheduler + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) @@ -221,7 +221,7 @@ class GlmImageDecoderPipeline(DiffusionPipeline, CogView4LoraLoaderMixin): dtype: Optional[torch.dtype] = None, ): device = device or self._execution_device - dtype = dtype or self.glyph_encoder.dtype + dtype = dtype or self.text_encoder.dtype glyph_texts = self.get_glyph_texts(prompt) input_ids = self.tokenizer( @@ -240,7 +240,7 @@ class GlmImageDecoderPipeline(DiffusionPipeline, CogView4LoraLoaderMixin): [input_ids_ + [self.tokenizer.pad_token_id] * (max_length - len(input_ids_)) for input_ids_ in input_ids], device=device, ) - outputs = self.glyph_encoder(input_ids, attention_mask=attention_mask) + outputs = self.text_encoder(input_ids, attention_mask=attention_mask) glyph_embeds = outputs.last_hidden_state[attention_mask.bool()].unsqueeze(0) return glyph_embeds.to(device=device, dtype=dtype) @@ -442,7 +442,7 @@ class GlmImageDecoderPipeline(DiffusionPipeline, CogView4LoraLoaderMixin): ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 2048, - ) -> Union[GlmImageDecoderPipelineOutput, Tuple]: + ) -> Union[GlmImagePipelineOutput, Tuple]: """ Function invoked when calling the pipeline for generation. 
@@ -615,7 +615,7 @@ class GlmImageDecoderPipeline(DiffusionPipeline, CogView4LoraLoaderMixin): ) if condition_images is not None: - self.transformer.set_attention_processors_state(GlmImageDecoderAttenProcessorState.ImageEditWriteKV) + self.transformer.set_attention_processors_state(GlmImageAttenProcessorState.ImageEditWriteKV) latents_mean = ( torch.tensor(self.vae.config.latents_mean) .view(1, self.vae.config.latent_channels, 1, 1) @@ -698,7 +698,7 @@ class GlmImageDecoderPipeline(DiffusionPipeline, CogView4LoraLoaderMixin): timestep = t.expand(latents.shape[0]) - 1 if condition_images is not None: - self.transformer.set_attention_processors_state(GlmImageDecoderAttenProcessorState.ImageEditReadKV) + self.transformer.set_attention_processors_state(GlmImageAttenProcessorState.ImageEditReadKV) noise_pred_cond = self.transformer( hidden_states=latent_model_input, @@ -717,7 +717,7 @@ class GlmImageDecoderPipeline(DiffusionPipeline, CogView4LoraLoaderMixin): if self.do_classifier_free_guidance: if condition_images is not None: self.transformer.set_attention_processors_state( - GlmImageDecoderAttenProcessorState.ImageEditDontReadKV + GlmImageAttenProcessorState.ImageEditDontReadKV ) noise_pred_uncond = self.transformer( hidden_states=latent_model_input, @@ -755,7 +755,7 @@ class GlmImageDecoderPipeline(DiffusionPipeline, CogView4LoraLoaderMixin): xm.mark_step() self._current_timestep = None - self.transformer.set_attention_processors_state(GlmImageDecoderAttenProcessorState.ImageGen) + self.transformer.set_attention_processors_state(GlmImageAttenProcessorState.ImageGen) self.transformer.clear_attention_processors_cache() if not output_type == "latent": @@ -783,4 +783,4 @@ class GlmImageDecoderPipeline(DiffusionPipeline, CogView4LoraLoaderMixin): if not return_dict: return (condition_images,) - return GlmImageDecoderPipelineOutput(images=condition_images) + return GlmImagePipelineOutput(images=condition_images) diff --git a/src/diffusers/pipelines/glm_image/pipeline_output.py b/src/diffusers/pipelines/glm_image/pipeline_output.py index a506e527cb..aec5a5454e 100644 --- a/src/diffusers/pipelines/glm_image/pipeline_output.py +++ b/src/diffusers/pipelines/glm_image/pipeline_output.py @@ -8,7 +8,7 @@ from ...utils import BaseOutput @dataclass -class GlmImageDecoderPipelineOutput(BaseOutput): +class GlmImagePipelineOutput(BaseOutput): """ Output class for CogView3 pipelines. diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index b36d7f3ccf..d2355d4737 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -967,7 +967,7 @@ class HiDreamImageTransformer2DModel(metaclass=DummyObject): requires_backends(cls, ["torch"]) -class GlmImageDecoderTransformer2DModel(metaclass=DummyObject): +class GlmImageTransformer2DModel(metaclass=DummyObject): _backends = ["torch"] def __init__(self, *args, **kwargs):