mirror of https://github.com/huggingface/diffusers.git
@@ -18,7 +18,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from einops import rearrange
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
@@ -376,7 +375,7 @@ class GlmImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
             The number of heads to use for multi-head attention.
         out_channels (`int`, defaults to `16`):
             The number of channels in the output.
-        text_embed_dim (`int`, defaults to `4096`):
+        text_embed_dim (`int`, defaults to `1472`):
             Input dimension of text embeddings from the text encoder.
         time_embed_dim (`int`, defaults to `512`):
             Output dimension of timestep embeddings.
@@ -428,8 +427,7 @@ class GlmImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
 
         # 2. Patch & Text-timestep embedding
         self.image_projector = GlmImageImageProjector(in_channels, inner_dim, patch_size)
-        self.text_projector = FeedForward(text_embed_dim, inner_dim, inner_dim=inner_dim, activation_fn="gelu")
-
+        self.glyph_projector = FeedForward(text_embed_dim, inner_dim, activation_fn="gelu")
         self.prior_token_embedding = nn.Embedding(prior_vq_quantizer_codebook_size, inner_dim)
         self.prior_projector = FeedForward(inner_dim, inner_dim, inner_dim=inner_dim, activation_fn="linear-silu")
 
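A rough shape sketch of the conditioning path set up above, using plain `torch.nn` stand-ins rather than diffusers' `FeedForward`. Only `text_embed_dim = 1472` and the layer names come from this diff; `inner_dim`, the codebook size, the hidden widths, and the batch/sequence sizes are illustrative assumptions.

```python
import torch
import torch.nn as nn

# Illustrative stand-ins only -- not the FeedForward blocks used in the model;
# the hidden widths below are guesses, not the diff's actual configuration.
text_embed_dim = 1472   # new default in this diff (text/glyph encoder hidden size)
inner_dim = 2048        # assumed num_attention_heads * attention_head_dim
codebook_size = 16384   # assumed prior_vq_quantizer_codebook_size

glyph_projector = nn.Sequential(
    nn.Linear(text_embed_dim, 4 * text_embed_dim), nn.GELU(), nn.Linear(4 * text_embed_dim, inner_dim)
)
prior_token_embedding = nn.Embedding(codebook_size, inner_dim)
prior_projector = nn.Sequential(
    nn.Linear(inner_dim, inner_dim), nn.SiLU(), nn.Linear(inner_dim, inner_dim)
)

encoder_hidden_states = torch.randn(2, 77, text_embed_dim)  # (batch, text_len, text_embed_dim)
prior_token_id = torch.randint(0, codebook_size, (2, 256))  # (batch, num_prior_tokens)

print(glyph_projector(encoder_hidden_states).shape)                  # torch.Size([2, 77, 2048])
print(prior_projector(prior_token_embedding(prior_token_id)).shape)  # torch.Size([2, 256, 2048])
```

Both conditioning streams are projected to `inner_dim`, presumably the same width `image_projector` maps the image patches to.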
@@ -482,7 +480,7 @@ class GlmImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
         post_patch_width = width // p
 
         hidden_states = self.image_projector(hidden_states)
-        encoder_hidden_states = self.text_projector(encoder_hidden_states)
+        encoder_hidden_states = self.glyph_projector(encoder_hidden_states)
         prior_embedding = self.prior_token_embedding(prior_token_id)
         prior_embedding[prior_token_drop] *= 0.0
         prior_hidden_states = self.prior_projector(prior_embedding)
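The `prior_embedding[prior_token_drop] *= 0.0` line zeroes the prior-token conditioning wherever the drop mask is set, presumably so the prior can be dropped, e.g. for classifier-free guidance. A tiny illustration with made-up shapes, assuming a per-sample boolean mask:

```python
import torch

prior_embedding = torch.ones(4, 8, 16)                       # (batch, num_prior_tokens, inner_dim), dummy values
prior_token_drop = torch.tensor([True, False, True, False])  # hypothetical per-sample drop mask

prior_embedding[prior_token_drop] *= 0.0   # zero the conditioning for the masked samples, in place
print(prior_embedding.sum(dim=(1, 2)))     # tensor([  0., 128.,   0., 128.])
```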