Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00

Commit: rename
@@ -354,7 +354,7 @@
- local: api/models/flux_transformer
title: FluxTransformer2DModel
- local: api/models/glm_image_transformer2d
title: GlmImageDecoderTransformer2DModel
title: GlmImageTransformer2DModel
- local: api/models/hidream_image_transformer
title: HiDreamImageTransformer2DModel
- local: api/models/hunyuan_transformer2d
@@ -9,10 +9,10 @@ Unless required by applicable law or agreed to in writing, software distributed
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# GlmImageDecoderTransformer2DModel
# GlmImageTransformer2DModel

A Diffusion Transformer model for 2D data from [GlmImageDecoderTransformer2DModel]()
A Diffusion Transformer model for 2D data from [GlmImageTransformer2DModel]()

## GlmImageDecoderTransformer2DModel
## GlmImageTransformer2DModel

[[autodoc]] GlmImageDecoderTransformer2DModel
[[autodoc]] GlmImageTransformer2DModel

@@ -20,12 +20,12 @@

This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/zai-org). The original weights can be found under [hf.co/zai-org](https://huggingface.co/zai-org).

## GlmImageDecoderPipeline
## GlmImagePipeline

[[autodoc]] GlmImageDecoderPipeline
[[autodoc]] GlmImagePipeline
- all
- __call__

## GlmImageDecoderPipelineOutput
## GlmImagePipelineOutput

[[autodoc]] pipelines.cogview4.pipeline_output.GlmImageDecoderPipelineOutput
[[autodoc]] pipelines.cogview4.pipeline_output.GlmImagePipelineOutput
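For reference, the renamed pipeline is loaded like any other diffusers pipeline. A minimal sketch based on the EXAMPLE_DOC_STRING later in this commit; the `THUDM/CogView4-6B` checkpoint id is the placeholder used there, not a confirmed GLM-Image repository:

```python
import torch
from diffusers import GlmImagePipeline

# Placeholder checkpoint id taken from the example docstring in this commit.
pipe = GlmImagePipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]  # GlmImagePipelineOutput exposes `.images`
image.save("glm_image.png")
```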
@@ -223,7 +223,7 @@ else:
"FluxControlNetModel",
"FluxMultiControlNetModel",
"FluxTransformer2DModel",
"GlmImageDecoderTransformer2DModel",
"GlmImageTransformer2DModel",
"HiDreamImageTransformer2DModel",
"HunyuanDiT2DControlNetModel",
"HunyuanDiT2DModel",

@@ -488,7 +488,7 @@ else:
"FluxKontextPipeline",
"FluxPipeline",
"FluxPriorReduxPipeline",
"GlmImageDecoderPipeline",
"GlmImagePipeline",
"HiDreamImagePipeline",
"HunyuanDiTControlNetPipeline",
"HunyuanDiTPAGPipeline",

@@ -971,7 +971,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxControlNetModel,
FluxMultiControlNetModel,
FluxTransformer2DModel,
GlmImageDecoderTransformer2DModel,
GlmImageTransformer2DModel,
HiDreamImageTransformer2DModel,
HunyuanDiT2DControlNetModel,
HunyuanDiT2DModel,

@@ -1206,7 +1206,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxKontextPipeline,
FluxPipeline,
FluxPriorReduxPipeline,
GlmImageDecoderPipeline,
GlmImagePipeline,
HiDreamImagePipeline,
HunyuanDiTControlNetPipeline,
HunyuanDiTPAGPipeline,
@@ -96,7 +96,7 @@ if is_torch_available():
_import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"]
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
_import_structure["transformers.transformer_flux2"] = ["Flux2Transformer2DModel"]
_import_structure["transformers.transformer_glm_image"] = ["GlmImageDecoderTransformer2DModel"]
_import_structure["transformers.transformer_glm_image"] = ["GlmImageTransformer2DModel"]
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
_import_structure["transformers.transformer_hunyuan_video15"] = ["HunyuanVideo15Transformer3DModel"]

@@ -204,7 +204,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
EasyAnimateTransformer3DModel,
Flux2Transformer2DModel,
FluxTransformer2DModel,
GlmImageDecoderTransformer2DModel,
GlmImageTransformer2DModel,
HiDreamImageTransformer2DModel,
HunyuanDiT2DModel,
HunyuanImageTransformer2DModel,
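The rename touches pairs of hunks in each `__init__.py` because diffusers registers public classes twice: once in the lazy `_import_structure` mapping used at runtime and once under the `TYPE_CHECKING` / `DIFFUSERS_SLOW_IMPORT` branch used by static tooling. A minimal sketch of that lazy-import pattern; the `_LazyModule` call follows the shape used elsewhere in diffusers and its exact arguments may differ:

```python
import sys
from typing import TYPE_CHECKING

from diffusers.utils import DIFFUSERS_SLOW_IMPORT, _LazyModule

_import_structure = {"transformers.transformer_glm_image": ["GlmImageTransformer2DModel"]}

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # Static analyzers (and slow-import mode) resolve the symbol eagerly.
    from .transformers.transformer_glm_image import GlmImageTransformer2DModel
else:
    # At runtime the module is swapped for a lazy proxy that imports on first access.
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
```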
@@ -1658,7 +1658,7 @@ class CogView3CombinedTimestepSizeEmbeddings(nn.Module):
return conditioning


class GlmImageDecoderCombinedTimestepSizeEmbeddings(nn.Module):
class GlmImageCombinedTimestepSizeEmbeddings(nn.Module):
def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256):
super().__init__()
@@ -27,7 +27,7 @@ if is_torch_available():
from .transformer_easyanimate import EasyAnimateTransformer3DModel
from .transformer_flux import FluxTransformer2DModel
from .transformer_flux2 import Flux2Transformer2DModel
from .transformer_glm_image import GlmImageDecoderTransformer2DModel
from .transformer_glm_image import GlmImageTransformer2DModel
from .transformer_hidream_image import HiDreamImageTransformer2DModel
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
from .transformer_hunyuan_video15 import HunyuanVideo15Transformer3DModel
@@ -27,7 +27,7 @@ from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention
from ..cache_utils import CacheMixin
from ..embeddings import GlmImageDecoderCombinedTimestepSizeEmbeddings
from ..embeddings import GlmImageCombinedTimestepSizeEmbeddings
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import LayerNorm, RMSNorm

@@ -36,7 +36,7 @@ from ..normalization import LayerNorm, RMSNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name


class GlmImageDecoderImageProjector(nn.Module):
class GlmImageImageProjector(nn.Module):
def __init__(
self,
in_channels: int = 16,

@@ -62,7 +62,7 @@ class GlmImageDecoderImageProjector(nn.Module):
return hidden_states


class GlmImageDecoderAdaLayerNormZero(nn.Module):
class GlmImageAdaLayerNormZero(nn.Module):
def __init__(self, embedding_dim: int, dim: int) -> None:
super().__init__()

@@ -71,11 +71,11 @@ class GlmImageDecoderAdaLayerNormZero(nn.Module):
self.linear = nn.Linear(embedding_dim, 12 * dim, bias=True)

def forward(
self, hidden_states: torch.Tensor, glyph_hidden_states: torch.Tensor, temb: torch.Tensor
self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
dtype = hidden_states.dtype
norm_hidden_states = self.norm(hidden_states).to(dtype=dtype)
norm_glyph_hidden_states = self.norm_context(glyph_hidden_states).to(dtype=dtype)
norm_encoder_hidden_states = self.norm_context(encoder_hidden_states).to(dtype=dtype)

emb = self.linear(temb)
(

@@ -94,7 +94,7 @@ class GlmImageDecoderAdaLayerNormZero(nn.Module):
) = emb.chunk(12, dim=1)

hidden_states = norm_hidden_states * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
glyph_hidden_states = norm_glyph_hidden_states * (1 + c_scale_msa.unsqueeze(1)) + c_shift_msa.unsqueeze(1)
encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_msa.unsqueeze(1)) + c_shift_msa.unsqueeze(1)

return (
hidden_states,

@@ -102,7 +102,7 @@ class GlmImageDecoderAdaLayerNormZero(nn.Module):
shift_mlp,
scale_mlp,
gate_mlp,
glyph_hidden_states,
encoder_hidden_states,
c_gate_msa,
c_shift_mlp,
c_scale_mlp,

@@ -110,17 +110,17 @@ class GlmImageDecoderAdaLayerNormZero(nn.Module):
)

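The `GlmImageAdaLayerNormZero` above projects the timestep embedding to `12 * dim` and splits it into shift/scale/gate triples for the attention and feed-forward sub-layers of both the image stream and the conditioning (`encoder_hidden_states`) stream. A minimal shape sketch of that chunking; the exact ordering of the twelve chunks is not visible in the hunks above and is assumed here:

```python
import torch
import torch.nn as nn

batch, dim, embedding_dim = 2, 2560, 512
linear = nn.Linear(embedding_dim, 12 * dim, bias=True)

temb = torch.randn(batch, embedding_dim)
emb = linear(temb)  # (batch, 12 * dim)

# Six modulation tensors for the image stream, six for the conditioning stream.
(
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp,
    c_shift_msa, c_scale_msa, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp,
) = emb.chunk(12, dim=1)

x = torch.randn(batch, 16, dim)  # (batch, seq, dim) hidden states after LayerNorm
x_mod = x * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
print(x_mod.shape)  # torch.Size([2, 16, 2560])
```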
class GlmImageDecoderAttenProcessorState(Enum):
class GlmImageAttenProcessorState(Enum):
ImageGen = "ImageGen"
ImageEditWriteKV = "ImageEditWriteKV"
ImageEditReadKV = "ImageEditReadKV"
ImageEditDontReadKV = "ImageEditNoReadKV"


class GlmImageDecoderAttnProcessor:
class GlmImageAttnProcessor:
"""
Processor for implementing scaled dot-product attention for the GlmImageDecoder model. It applies a rotary
embedding on query and key vectors, but does not include spatial normalization.
Processor for implementing scaled dot-product attention for the GlmImage model. It applies a rotary embedding on
query and key vectors, but does not include spatial normalization.

The processor supports passing an attention mask for text tokens. The attention mask should have shape (batch_size,
text_seq_length) where 1 indicates a non-padded token and 0 indicates a padded token.

@@ -128,10 +128,8 @@ class GlmImageDecoderAttnProcessor:

def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"GlmImageDecoderAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
)
self.processor_state = GlmImageDecoderAttenProcessorState.ImageGen
raise ImportError("GlmImageAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
self.processor_state = GlmImageAttenProcessorState.ImageGen
self.k_cache = None
self.v_cache = None

@@ -175,10 +173,10 @@ class GlmImageDecoderAttnProcessor:
key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
)

if self.processor_state == GlmImageDecoderAttenProcessorState.ImageEditWriteKV:
if self.processor_state == GlmImageAttenProcessorState.ImageEditWriteKV:
self.k_cache = key if self.k_cache is None else torch.cat([self.k_cache, key], dim=2)
self.v_cache = value if self.v_cache is None else torch.cat([self.v_cache, value], dim=2)
elif self.processor_state == GlmImageDecoderAttenProcessorState.ImageEditReadKV:
elif self.processor_state == GlmImageAttenProcessorState.ImageEditReadKV:
key = torch.cat([self.k_cache, key], dim=2) if self.k_cache is not None else key
value = torch.cat([self.v_cache, value], dim=2) if self.v_cache is not None else value
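The enum and processor above implement a small KV-cache state machine for image editing: in `ImageEditWriteKV` the keys/values of a conditioning pass are appended to a per-processor cache, in `ImageEditReadKV` later passes attend over the cached keys/values concatenated with their own, and `ImageGen` / `ImageEditDontReadKV` leave the current pass untouched. A self-contained toy sketch of just that caching behaviour (illustration only, not the diffusers processor):

```python
from enum import Enum

import torch


class State(Enum):
    ImageGen = "ImageGen"
    ImageEditWriteKV = "ImageEditWriteKV"
    ImageEditReadKV = "ImageEditReadKV"


class ToyKVCache:
    def __init__(self):
        self.state = State.ImageGen
        self.k_cache = None
        self.v_cache = None

    def __call__(self, key, value):
        # key/value: (batch, heads, seq, head_dim); dim=2 is the sequence axis.
        if self.state == State.ImageEditWriteKV:
            self.k_cache = key if self.k_cache is None else torch.cat([self.k_cache, key], dim=2)
            self.v_cache = value if self.v_cache is None else torch.cat([self.v_cache, value], dim=2)
        elif self.state == State.ImageEditReadKV and self.k_cache is not None:
            key = torch.cat([self.k_cache, key], dim=2)
            value = torch.cat([self.v_cache, value], dim=2)
        return key, value


cache = ToyKVCache()
k = v = torch.randn(1, 8, 16, 64)

cache.state = State.ImageEditWriteKV
cache(k, v)            # caches the condition-image keys/values

cache.state = State.ImageEditReadKV
k2, v2 = cache(k, v)
print(k2.shape)        # torch.Size([1, 8, 32, 64]) -- cached + current tokens
```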
@@ -210,7 +208,7 @@ class GlmImageDecoderAttnProcessor:


@maybe_allow_in_graph
class GlmImageDecoderTransformerBlock(nn.Module):
class GlmImageTransformerBlock(nn.Module):
def __init__(
self,
dim: int = 2560,

@@ -221,7 +219,7 @@ class GlmImageDecoderTransformerBlock(nn.Module):
super().__init__()

# 1. Attention
self.norm1 = GlmImageDecoderAdaLayerNormZero(time_embed_dim, dim)
self.norm1 = GlmImageAdaLayerNormZero(time_embed_dim, dim)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,

@@ -231,7 +229,7 @@ class GlmImageDecoderTransformerBlock(nn.Module):
qk_norm="layer_norm",
elementwise_affine=False,
eps=1e-5,
processor=GlmImageDecoderAttnProcessor(),
processor=GlmImageAttnProcessor(),
)

# 2. Feedforward

@@ -242,7 +240,7 @@ class GlmImageDecoderTransformerBlock(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
glyph_hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[
Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]

@@ -257,42 +255,42 @@ class GlmImageDecoderTransformerBlock(nn.Module):
shift_mlp,
scale_mlp,
gate_mlp,
norm_glyph_hidden_states,
norm_encoder_hidden_states,
c_gate_msa,
c_shift_mlp,
c_scale_mlp,
c_gate_mlp,
) = self.norm1(hidden_states, glyph_hidden_states, temb)
) = self.norm1(hidden_states, encoder_hidden_states, temb)

# 2. Attention
if attention_kwargs is None:
attention_kwargs = {}

attn_hidden_states, attn_glyph_hidden_states = self.attn1(
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_glyph_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask,
**attention_kwargs,
)
hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1)
glyph_hidden_states = glyph_hidden_states + attn_glyph_hidden_states * c_gate_msa.unsqueeze(1)
encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1)

# 3. Feedforward
norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
norm_glyph_hidden_states = self.norm2_context(glyph_hidden_states) * (
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) * (
1 + c_scale_mlp.unsqueeze(1)
) + c_shift_mlp.unsqueeze(1)

ff_output = self.ff(norm_hidden_states)
ff_output_context = self.ff(norm_glyph_hidden_states)
ff_output_context = self.ff(norm_encoder_hidden_states)
hidden_states = hidden_states + ff_output * gate_mlp.unsqueeze(1)
glyph_hidden_states = glyph_hidden_states + ff_output_context * c_gate_mlp.unsqueeze(1)
encoder_hidden_states = encoder_hidden_states + ff_output_context * c_gate_mlp.unsqueeze(1)

return hidden_states, glyph_hidden_states
return hidden_states, encoder_hidden_states


class GlmImageDecoderRotaryPosEmbed(nn.Module):
class GlmImageRotaryPosEmbed(nn.Module):
def __init__(self, dim: int, patch_size: int, theta: float = 10000.0) -> None:
super().__init__()

@@ -331,10 +329,10 @@ class GlmImageDecoderRotaryPosEmbed(nn.Module):
return (freqs.cos(), freqs.sin())


class GlmImageDecoderAdaLayerNormContinuous(nn.Module):
class GlmImageAdaLayerNormContinuous(nn.Module):
"""
GlmImageDecoder-only final AdaLN: LN(x) -> Linear(cond) -> chunk -> affine. Matches Megatron: **no activation**
before the Linear on conditioning embedding.
GlmImage-only final AdaLN: LN(x) -> Linear(cond) -> chunk -> affine. Matches Megatron: **no activation** before the
Linear on conditioning embedding.
"""

def __init__(

@@ -363,7 +361,7 @@ class GlmImageDecoderAdaLayerNormContinuous(nn.Module):
return x


class GlmImageDecoderTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
class GlmImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
r"""
Args:
patch_size (`int`, defaults to `2`):

@@ -397,9 +395,9 @@ class GlmImageDecoderTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixi

_supports_gradient_checkpointing = True
_no_split_modules = [
"GlmImageDecoderTransformerBlock",
"GlmImageDecoderImageProjector",
"GlmImageDecoderImageProjector",
"GlmImageTransformerBlock",
"GlmImageImageProjector",
"GlmImageImageProjector",
]
_skip_layerwise_casting_patterns = ["patch_embed", "norm", "proj_out"]
@@ -412,35 +410,30 @@ class GlmImageDecoderTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
num_layers: int = 30,
attention_head_dim: int = 40,
num_attention_heads: int = 64,
text_embed_dim: int = 4096,
glyph_embed_dim: int = 1472,
text_embed_dim: int = 1472,
time_embed_dim: int = 512,
condition_dim: int = 256,
pos_embed_max_size: int = 128,
sample_size: int = 128,
prior_vq_quantizer_codebook_size: int = 16384,
):
super().__init__()

# GlmImageDecoder uses 2 additional SDXL-like conditions - target_size, crop_coords
# GlmImage uses 2 additional SDXL-like conditions - target_size, crop_coords
# Each of these are sincos embeddings of shape 2 * condition_dim
pooled_projection_dim = 2 * 2 * condition_dim
inner_dim = num_attention_heads * attention_head_dim
out_channels = out_channels

# 1. RoPE
self.rope = GlmImageDecoderRotaryPosEmbed(attention_head_dim, patch_size, theta=10000.0)
self.rope = GlmImageRotaryPosEmbed(attention_head_dim, patch_size, theta=10000.0)

# 2. Patch & Text-timestep embedding
self.image_projector = GlmImageDecoderImageProjector(in_channels, inner_dim, patch_size)
# Not present this time; a text_projector may be added in the future
# self.text_projector = FeedForward(text_embed_dim, inner_dim, activation_fn="gelu")
self.glyph_projector = FeedForward(glyph_embed_dim, inner_dim, inner_dim=inner_dim, activation_fn="gelu")
self.image_projector = GlmImageImageProjector(in_channels, inner_dim, patch_size)
self.text_projector = FeedForward(text_embed_dim, inner_dim, inner_dim=inner_dim, activation_fn="gelu")

self.prior_token_embedding = nn.Embedding(prior_vq_quantizer_codebook_size, inner_dim)
self.prior_projector = FeedForward(inner_dim, inner_dim, inner_dim=inner_dim, activation_fn="linear-silu")

self.time_condition_embed = GlmImageDecoderCombinedTimestepSizeEmbeddings(
self.time_condition_embed = GlmImageCombinedTimestepSizeEmbeddings(
embedding_dim=time_embed_dim,
condition_dim=condition_dim,
pooled_projection_dim=pooled_projection_dim,
@@ -450,13 +443,13 @@ class GlmImageDecoderTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
# 3. Transformer blocks
self.transformer_blocks = nn.ModuleList(
[
GlmImageDecoderTransformerBlock(inner_dim, num_attention_heads, attention_head_dim, time_embed_dim)
GlmImageTransformerBlock(inner_dim, num_attention_heads, attention_head_dim, time_embed_dim)
for _ in range(num_layers)
]
)

# 4. Output projection
self.norm_out = GlmImageDecoderAdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
self.norm_out = GlmImageAdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels, bias=True)

self.gradient_checkpointing = False

@@ -464,11 +457,10 @@ class GlmImageDecoderTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
def forward(
self,
hidden_states: torch.Tensor,
glyph_hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
prior_token_id: torch.Tensor,
prior_token_drop: torch.Tensor,
timestep: torch.LongTensor,
original_size: torch.Tensor,
target_size: torch.Tensor,
crop_coords: torch.Tensor,
attention_kwargs: Optional[Dict[str, Any]] = None,

@@ -490,7 +482,7 @@ class GlmImageDecoderTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
post_patch_width = width // p

hidden_states = self.image_projector(hidden_states)
glyph_hidden_states = self.glyph_projector(glyph_hidden_states)
encoder_hidden_states = self.text_projector(encoder_hidden_states)
prior_embedding = self.prior_token_embedding(prior_token_id)
prior_embedding[prior_token_drop] *= 0.0
prior_hidden_states = self.prior_projector(prior_embedding)

@@ -503,19 +495,19 @@ class GlmImageDecoderTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
# 3. Transformer blocks
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states, glyph_hidden_states = self._gradient_checkpointing_func(
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
glyph_hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
attention_mask,
attention_kwargs,
)
else:
hidden_states, glyph_hidden_states = block(
hidden_states, encoder_hidden_states = block(
hidden_states,
glyph_hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
attention_mask,

@@ -534,7 +526,7 @@ class GlmImageDecoderTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
return (output,)
return Transformer2DModelOutput(sample=output)

def set_attention_processors_state(self, state: GlmImageDecoderAttenProcessorState):
def set_attention_processors_state(self, state: GlmImageAttenProcessorState):
for block in self.transformer_blocks:
block.attn1.processor.processor_state = state
@@ -542,212 +534,3 @@ class GlmImageDecoderTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
for block in self.transformer_blocks:
block.attn1.processor.k_cache = None
block.attn1.processor.v_cache = None

def repeat_attention_processors_cache(self, repeats: int):
for block in self.transformer_blocks:
if block.attn1.processor.k_cache is None or block.attn1.processor.v_cache is None:
continue
block.attn1.processor.k_cache = torch.repeat_interleave(block.attn1.processor.k_cache, repeats, dim=2)
block.attn1.processor.v_cache = torch.repeat_interleave(block.attn1.processor.v_cache, repeats, dim=2)


if __name__ == "__main__":

def swap_scale_shift(weight, dim):
"""
Swap the scale and shift components in the weight tensor.

Args:
weight (torch.Tensor): The original weight tensor.
dim (int): The dimension along which to split.

Returns:
torch.Tensor: The modified weight tensor with scale and shift swapped.
"""
shift, scale = weight.chunk(2, dim=dim)
new_weight = torch.cat([scale, shift], dim=dim)
return new_weight

def convert_megatron_transformer_checkpoint_to_diffusers(
ckpt_path: str,
num_layers: int,
num_heads: int,
hidden_size: int,
):
"""
Convert a Megatron Transformer checkpoint to Diffusers format.

Args:
ckpt_path (str): Path to the Megatron Transformer checkpoint.
num_layers (int): Number of Transformer layers.
num_heads (int): Number of attention heads.
hidden_size (int): Hidden size of the Transformer.

Returns:
dict: The converted state dictionary compatible with Diffusers.
"""
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
mega = ckpt["model"]
used_keys = set()

def get_mega(key):
used_keys.add(key)
return mega[key]

new_state_dict = {}

# Patch Embedding
new_state_dict["image_projector.proj.weight"] = get_mega("encoder_expand_linear.weight").reshape(
hidden_size, 64
)
new_state_dict["image_projector.proj.bias"] = get_mega("encoder_expand_linear.bias")

new_state_dict["glyph_projector.net.0.proj.weight"] = get_mega("glyph_projector.linear_fc1.weight")
new_state_dict["glyph_projector.net.0.proj.bias"] = get_mega("glyph_projector.linear_fc1.bias")
new_state_dict["glyph_projector.net.2.weight"] = get_mega("glyph_projector.linear_fc2.weight")
new_state_dict["glyph_projector.net.2.bias"] = get_mega("glyph_projector.linear_fc2.bias")

new_state_dict["prior_token_embedding.weight"] = get_mega("xomni_token_id_embedding.weight")
new_state_dict["prior_projector.net.0.proj.weight"] = get_mega("prior_condition_embedding.0.weight")
new_state_dict["prior_projector.net.0.proj.bias"] = get_mega("prior_condition_embedding.0.bias")
new_state_dict["prior_projector.net.2.weight"] = get_mega("prior_condition_embedding.2.weight")
new_state_dict["prior_projector.net.2.bias"] = get_mega("prior_condition_embedding.2.bias")

# Time Condition Embedding
new_state_dict["time_condition_embed.timestep_embedder.linear_1.weight"] = get_mega(
"time_embedding.time_embed.0.weight"
)
new_state_dict["time_condition_embed.timestep_embedder.linear_1.bias"] = get_mega(
"time_embedding.time_embed.0.bias"
)
new_state_dict["time_condition_embed.timestep_embedder.linear_2.weight"] = get_mega(
"time_embedding.time_embed.2.weight"
)
new_state_dict["time_condition_embed.timestep_embedder.linear_2.bias"] = get_mega(
"time_embedding.time_embed.2.bias"
)

new_state_dict["time_condition_embed.condition_embedder.linear_1.weight"] = get_mega(
"label_embedding.label_embed.0.weight"
)
new_state_dict["time_condition_embed.condition_embedder.linear_1.bias"] = get_mega(
"label_embedding.label_embed.0.bias"
)
new_state_dict["time_condition_embed.condition_embedder.linear_2.weight"] = get_mega(
"label_embedding.label_embed.2.weight"
)
new_state_dict["time_condition_embed.condition_embedder.linear_2.bias"] = get_mega(
"label_embedding.label_embed.2.bias"
)

# Convert each Transformer layer
from tqdm import tqdm

for i in tqdm(range(num_layers), desc="Converting layers (Megatron->Diffusers)"):
block_prefix = f"transformer_blocks.{i}."

# AdaLayerNorm
new_state_dict[block_prefix + "norm1.linear.weight"] = get_mega(f"decoder.layers.{i}.adaln.weight")
new_state_dict[block_prefix + "norm1.linear.bias"] = get_mega(f"decoder.layers.{i}.adaln.bias")
qkv_weight = get_mega(f"decoder.layers.{i}.self_attention.linear_qkv.weight")
qkv_bias = get_mega(f"decoder.layers.{i}.self_attention.linear_qkv.bias")

# Reshape to match SAT logic
qkv_weight = qkv_weight.view(num_heads, 3, hidden_size // num_heads, hidden_size)
qkv_weight = qkv_weight.permute(1, 0, 2, 3).reshape(3 * hidden_size, hidden_size)

qkv_bias = qkv_bias.view(num_heads, 3, hidden_size // num_heads)
qkv_bias = qkv_bias.permute(1, 0, 2).reshape(3 * hidden_size)

# Assign to Diffusers keys
q, k, v = torch.chunk(qkv_weight, 3, dim=0)
qb, kb, vb = torch.chunk(qkv_bias, 3, dim=0)

new_state_dict[block_prefix + "attn1.to_q.weight"] = q
new_state_dict[block_prefix + "attn1.to_q.bias"] = qb
new_state_dict[block_prefix + "attn1.to_k.weight"] = k
new_state_dict[block_prefix + "attn1.to_k.bias"] = kb
new_state_dict[block_prefix + "attn1.to_v.weight"] = v
new_state_dict[block_prefix + "attn1.to_v.bias"] = vb

# Attention Output
new_state_dict[block_prefix + "attn1.to_out.0.weight"] = get_mega(
f"decoder.layers.{i}.self_attention.linear_proj.weight"
)
new_state_dict[block_prefix + "attn1.to_out.0.bias"] = get_mega(
f"decoder.layers.{i}.self_attention.linear_proj.bias"
)

# MLP
new_state_dict[block_prefix + "ff.net.0.proj.weight"] = get_mega(
f"decoder.layers.{i}.mlp.linear_fc1.weight"
)
new_state_dict[block_prefix + "ff.net.0.proj.bias"] = get_mega(f"decoder.layers.{i}.mlp.linear_fc1.bias")
new_state_dict[block_prefix + "ff.net.2.weight"] = get_mega(f"decoder.layers.{i}.mlp.linear_fc2.weight")
new_state_dict[block_prefix + "ff.net.2.bias"] = get_mega(f"decoder.layers.{i}.mlp.linear_fc2.bias")

# Final Layers
new_state_dict["norm_out.linear.weight"] = swap_scale_shift(get_mega("adaln_final.weight"), dim=0)
new_state_dict["norm_out.linear.bias"] = swap_scale_shift(get_mega("adaln_final.bias"), dim=0)
new_state_dict["proj_out.weight"] = get_mega("output_projector.weight")
new_state_dict["proj_out.bias"] = get_mega("output_projector.bias")

# Check for unused keys
all_keys = set(mega.keys())
unused_keys = all_keys - used_keys
if unused_keys:
print(f"\n[WARNING] The following {len(unused_keys)} keys in mega were NOT used:")
for key in sorted(unused_keys):
print(f" - {key}")
raise ValueError(
f"Found {len(unused_keys)} unused keys in Megatron checkpoint. Please update the conversion script to handle these keys."
)
else:
print(f"\n[INFO] All {len(all_keys)} keys in mega were successfully used.")

return new_state_dict

transformer = GlmImageDecoderTransformer2DModel(
patch_size=2,
in_channels=16,
num_layers=30,
attention_head_dim=128,
num_attention_heads=32,
out_channels=16,
text_embed_dim=4096,
time_embed_dim=512,
glyph_embed_dim=1472,
condition_dim=256,
pos_embed_max_size=128,
).to(torch.bfloat16)
converted_transformer_state_dict = convert_megatron_transformer_checkpoint_to_diffusers(
ckpt_path="/workspace/ckpt/tjy/Glm-train-dev/examples/cogview/ckpts/merge/1+6_0.5+0.5/iter_0000000/mp_rank_00/model_optim_rng.pt",
num_layers=30,
num_heads=32,
hidden_size=4096,
)
transformer.load_state_dict(converted_transformer_state_dict)
transformer.cuda()

latent = torch.load("/workspace/ckpt/tjy/glm-train-dev/examples/cogview/latent.pt").to(torch.bfloat16)
latent = rearrange(latent, "(b h w) (c p q) -> b c (h p) (w q)", b=8, h=72, w=54, p=2, q=2)
glyph_hidden_states = torch.load(
"/workspace/ckpt/tjy/glm-train-dev/examples/cogview/glyph_condition_embedding.pt"
).to(torch.bfloat16)
glyph_hidden_states = rearrange(glyph_hidden_states, "(b n) c -> b n c", b=8, n=2)
prior_token_id = torch.load("/workspace/ckpt/tjy/glm-train-dev/examples/cogview/xomni_token_id.pt")
prior_token_drop = torch.load("/workspace/ckpt/tjy/glm-train-dev/examples/cogview/xomni_drop.pt")
prior_token_id = rearrange(prior_token_id, "(b n) -> b n", b=8)
prior_token_drop = rearrange(prior_token_drop, "(b n)-> b n", b=8)

with torch.no_grad():
output = transformer(
hidden_states=latent,
glyph_hidden_states=glyph_hidden_states,
prior_token_id=prior_token_id,
prior_token_drop=prior_token_drop,
timestep=torch.tensor([999.0] * 8).cuda(),
original_size=torch.tensor([[144, 108]] * 8).cuda(),
target_size=torch.tensor([[144, 108]] * 8).cuda(),
crop_coords=torch.tensor([[0, 0]] * 8).cuda(),
)
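The deleted `__main__` block above was a one-off Megatron-to-diffusers conversion and smoke test. Its two non-obvious weight-layout transforms are `swap_scale_shift`, which reorders the final AdaLN projection from `[shift, scale]` to `[scale, shift]` as the code reads it, and the un-interleaving of Megatron's fused per-head QKV matrix into contiguous Q, K, V blocks. A toy-sized sketch of both transforms:

```python
import torch

hidden_size, num_heads = 8, 2
head_dim = hidden_size // num_heads

# 1. swap_scale_shift: chunk into [shift, scale], re-concatenate as [scale, shift].
w = torch.arange(4.0)                    # [0, 1 | 2, 3]
shift, scale = w.chunk(2, dim=0)
print(torch.cat([scale, shift], dim=0))  # tensor([2., 3., 0., 1.])

# 2. Megatron stores QKV fused per head as [q_h, k_h, v_h]; the view/permute above
#    regroups it so a 3-way chunk yields the to_q / to_k / to_v weights.
qkv_weight = torch.randn(3 * hidden_size, hidden_size)
qkv_weight = qkv_weight.view(num_heads, 3, head_dim, hidden_size)
qkv_weight = qkv_weight.permute(1, 0, 2, 3).reshape(3 * hidden_size, hidden_size)
q, k, v = torch.chunk(qkv_weight, 3, dim=0)
print(q.shape, k.shape, v.shape)         # torch.Size([8, 8]) x 3
```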
@@ -52,7 +52,7 @@ from .flux import (
FluxKontextPipeline,
FluxPipeline,
)
from .glm_image import GlmImageDecoderPipeline
from .glm_image import GlmImagePipeline
from .hunyuandit import HunyuanDiTPipeline
from .kandinsky import (
KandinskyCombinedPipeline,

@@ -168,7 +168,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("chroma", ChromaPipeline),
("cogview3", CogView3PlusPipeline),
("cogview4", CogView4Pipeline),
("glm_image", GlmImageDecoderPipeline),
("glm_image", GlmImagePipeline),
("cogview4-control", CogView4ControlPipeline),
("qwenimage", QwenImagePipeline),
("qwenimage-controlnet", QwenImageControlNetPipeline),
@@ -12,7 +12,7 @@ from ...utils import (

_dummy_objects = {}
_additional_imports = {}
_import_structure = {"pipeline_output": ["GlmImageDecoderPipelineOutput"]}
_import_structure = {"pipeline_output": ["GlmImagePipelineOutput"]}

try:
if not (is_transformers_available() and is_torch_available()):

@@ -22,7 +22,7 @@ except OptionalDependencyNotAvailable:

_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_glm_image"] = ["GlmImageDecoderPipeline"]
_import_structure["pipeline_glm_image"] = ["GlmImagePipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):

@@ -30,7 +30,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .pipeline_glm_image import GlmImageDecoderPipeline
from .pipeline_glm_image import GlmImagePipeline
else:
import sys
@@ -25,13 +25,13 @@ from transformers import AutoTokenizer, T5EncoderModel
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import VaeImageProcessor
from ...loaders import CogView4LoraLoaderMixin
from ...models import AutoencoderKL, GlmImageDecoderTransformer2DModel
from ...models.transformers.transformer_glm_image import GlmImageDecoderAttenProcessorState
from ...models import AutoencoderKL, GlmImageTransformer2DModel
from ...models.transformers.transformer_glm_image import GlmImageAttenProcessorState
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from .pipeline_output import GlmImageDecoderPipelineOutput
from .pipeline_output import GlmImagePipelineOutput


if is_torch_xla_available():
@@ -47,9 +47,9 @@ EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from diffusers import GlmImageDecoderPipeline
>>> from diffusers import GlmImagePipeline

>>> pipe = GlmImageDecoderPipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
>>> pipe = GlmImagePipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")

>>> prompt = "A photo of an astronaut riding a horse on mars"
@@ -151,7 +151,7 @@ def retrieve_latents(
raise AttributeError("Could not access latents of provided encoder_output")


class GlmImageDecoderPipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
r"""
Pipeline for text-to-image generation using CogView4.

@@ -162,7 +162,7 @@ class GlmImageDecoderPipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`GlmModel`]):
Frozen text-encoder. CogView4 uses [Glm-4-9b-hf](https://huggingface.co/THUDM/Glm-4-9b-hf).
Frozen text-encoder. CogView4 uses [GLM-Image](https://huggingface.co/zai-org/GLM-Image).
tokenizer (`PreTrainedTokenizer`):
Tokenizer of class
[PreTrainedTokenizer](https://huggingface.co/docs/transformers/main/en/main_classes/tokenizer#transformers.PreTrainedTokenizer).

@@ -179,15 +179,15 @@ class GlmImageDecoderPipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
def __init__(
self,
tokenizer: AutoTokenizer,
glyph_encoder: T5EncoderModel,
text_encoder: T5EncoderModel,
vae: AutoencoderKL,
transformer: GlmImageDecoderTransformer2DModel,
transformer: GlmImageTransformer2DModel,
scheduler: FlowMatchEulerDiscreteScheduler,
):
super().__init__()

self.register_modules(
tokenizer=tokenizer, glyph_encoder=glyph_encoder, vae=vae, transformer=transformer, scheduler=scheduler
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

@@ -221,7 +221,7 @@ class GlmImageDecoderPipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.glyph_encoder.dtype
dtype = dtype or self.text_encoder.dtype

glyph_texts = self.get_glyph_texts(prompt)
input_ids = self.tokenizer(

@@ -240,7 +240,7 @@ class GlmImageDecoderPipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
[input_ids_ + [self.tokenizer.pad_token_id] * (max_length - len(input_ids_)) for input_ids_ in input_ids],
device=device,
)
outputs = self.glyph_encoder(input_ids, attention_mask=attention_mask)
outputs = self.text_encoder(input_ids, attention_mask=attention_mask)
glyph_embeds = outputs.last_hidden_state[attention_mask.bool()].unsqueeze(0)

return glyph_embeds.to(device=device, dtype=dtype)

@@ -442,7 +442,7 @@ class GlmImageDecoderPipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 2048,
) -> Union[GlmImageDecoderPipelineOutput, Tuple]:
) -> Union[GlmImagePipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.

@@ -615,7 +615,7 @@ class GlmImageDecoderPipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
)

if condition_images is not None:
self.transformer.set_attention_processors_state(GlmImageDecoderAttenProcessorState.ImageEditWriteKV)
self.transformer.set_attention_processors_state(GlmImageAttenProcessorState.ImageEditWriteKV)
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.latent_channels, 1, 1)

@@ -698,7 +698,7 @@ class GlmImageDecoderPipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
timestep = t.expand(latents.shape[0]) - 1

if condition_images is not None:
self.transformer.set_attention_processors_state(GlmImageDecoderAttenProcessorState.ImageEditReadKV)
self.transformer.set_attention_processors_state(GlmImageAttenProcessorState.ImageEditReadKV)

noise_pred_cond = self.transformer(
hidden_states=latent_model_input,

@@ -717,7 +717,7 @@ class GlmImageDecoderPipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
if self.do_classifier_free_guidance:
if condition_images is not None:
self.transformer.set_attention_processors_state(
GlmImageDecoderAttenProcessorState.ImageEditDontReadKV
GlmImageAttenProcessorState.ImageEditDontReadKV
)
noise_pred_uncond = self.transformer(
hidden_states=latent_model_input,

@@ -755,7 +755,7 @@ class GlmImageDecoderPipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
xm.mark_step()

self._current_timestep = None
self.transformer.set_attention_processors_state(GlmImageDecoderAttenProcessorState.ImageGen)
self.transformer.set_attention_processors_state(GlmImageAttenProcessorState.ImageGen)
self.transformer.clear_attention_processors_cache()

if not output_type == "latent":

@@ -783,4 +783,4 @@ class GlmImageDecoderPipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
if not return_dict:
return (condition_images,)

return GlmImageDecoderPipelineOutput(images=condition_images)
return GlmImagePipelineOutput(images=condition_images)

@@ -8,7 +8,7 @@ from ...utils import BaseOutput


@dataclass
class GlmImageDecoderPipelineOutput(BaseOutput):
class GlmImagePipelineOutput(BaseOutput):
"""
Output class for CogView3 pipelines.


@@ -967,7 +967,7 @@ class HiDreamImageTransformer2DModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])


class GlmImageDecoderTransformer2DModel(metaclass=DummyObject):
class GlmImageTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):