mirror of https://github.com/huggingface/diffusers.git

commit: fix
--- a/scripts/convert_mochi_to_diffusers.py
+++ b/scripts/convert_mochi_to_diffusers.py
@@ -4,8 +4,8 @@ from contextlib import nullcontext
 import torch
 from accelerate import init_empty_weights
 from safetensors.torch import load_file
-# from transformers import T5EncoderModel, T5Tokenizer

+# from transformers import T5EncoderModel, T5Tokenizer
 from diffusers import MochiTransformer3DModel
 from diffusers.utils.import_utils import is_accelerate_available

@@ -72,10 +72,12 @@ def convert_mochi_transformer_checkpoint_to_diffusers(ckpt_path):
                 old_prefix + "mod_y.bias"
             )
         else:
-            new_state_dict[block_prefix + "norm1_context.weight"] = original_state_dict.pop(
+            new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = original_state_dict.pop(
                 old_prefix + "mod_y.weight"
             )
-            new_state_dict[block_prefix + "norm1_context.bias"] = original_state_dict.pop(old_prefix + "mod_y.bias")
+            new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = original_state_dict.pop(
+                old_prefix + "mod_y.bias"
+            )

         # Visual attention
         qkv_weight = original_state_dict.pop(old_prefix + "attn.qkv_x.weight")
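Note on the pattern above: the converter pops each original tensor out of the checkpoint as it renames it, so any key left over at the end flags an unconverted parameter. A minimal sketch of that pop-and-rename pattern, with toy keys rather than the real Mochi checkpoint layout:

```python
# Toy state dict standing in for the original Mochi checkpoint (values would be tensors).
state = {"blocks.0.mod_y.weight": 1, "blocks.0.mod_y.bias": 2}
converted = {}
old_prefix, block_prefix = "blocks.0.", "transformer_blocks.0."

# pop() removes the source key while renaming, so nothing is silently duplicated.
converted[block_prefix + "norm1_context.linear_1.weight"] = state.pop(old_prefix + "mod_y.weight")
converted[block_prefix + "norm1_context.linear_1.bias"] = state.pop(old_prefix + "mod_y.bias")

# Any key still present was never mapped; useful as a conversion sanity check.
assert not state, f"unconverted keys remain: {list(state)}"
```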
@@ -158,7 +160,7 @@ def main(args):
         raise ValueError(f"Unsupported dtype: {args.dtype}")

     transformer = None
-    vae = None
+    # vae = None

     if args.transformer_checkpoint_path is not None:
         converted_transformer_state_dict = convert_mochi_transformer_checkpoint_to_diffusers(
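The `Unsupported dtype` guard above is the fall-through of a string-to-torch-dtype dispatch in `main`. A hedged sketch of that dispatch (the exact flag names the script accepts may differ):

```python
import torch

def resolve_dtype(name: str) -> torch.dtype:
    # Assumed mapping; the convert script may support a different set of names.
    mapping = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
    if name not in mapping:
        raise ValueError(f"Unsupported dtype: {name}")
    return mapping[name]

print(resolve_dtype("bf16"))  # torch.bfloat16
```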
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -1794,7 +1794,9 @@ class FluxAttnProcessor2_0:
             hidden_states = attn.to_out[0](hidden_states)
             # dropout
             hidden_states = attn.to_out[1](hidden_states)
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            if hasattr(attn, "to_add_out"):
+                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

             return hidden_states, encoder_hidden_states
         else:
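The new `hasattr` guard lets one processor serve attention blocks both with and without a context output projection. A minimal sketch of the pattern with a toy module (not the real diffusers `Attention` class):

```python
import torch
import torch.nn as nn

class ToyAttnBlock(nn.Module):
    def __init__(self, dim: int, has_context_out: bool):
        super().__init__()
        self.to_out = nn.Linear(dim, dim)
        if has_context_out:
            # Only some blocks project the context stream back out.
            self.to_add_out = nn.Linear(dim, dim)

def project(block: ToyAttnBlock, hidden: torch.Tensor, context: torch.Tensor):
    hidden = block.to_out(hidden)
    # Skip the context projection when the submodule was never created.
    if hasattr(block, "to_add_out"):
        context = block.to_add_out(context)
    return hidden, context

h, c = torch.randn(2, 4, 8), torch.randn(2, 3, 8)
project(ToyAttnBlock(8, has_context_out=False), h, c)  # context passes through untouched
```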
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -1304,7 +1304,12 @@ class LuminaCombinedTimestepCaptionEmbedding(nn.Module):

 class MochiCombinedTimestepCaptionEmbedding(nn.Module):
     def __init__(
-        self, embedding_dim: int, pooled_projection_dim: int, text_embed_dim: int, time_embed_dim: int = 256, num_attention_heads: int = 8
+        self,
+        embedding_dim: int,
+        pooled_projection_dim: int,
+        text_embed_dim: int,
+        time_embed_dim: int = 256,
+        num_attention_heads: int = 8,
     ) -> None:
         super().__init__()

--- a/src/diffusers/models/normalization.py
+++ b/src/diffusers/models/normalization.py
@@ -385,20 +385,21 @@ class LuminaLayerNormContinuous(nn.Module):
         out_dim: Optional[int] = None,
     ):
         super().__init__()
+
+        # AdaLN
         self.silu = nn.SiLU()
         self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
+
         if norm_type == "layer_norm":
             self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
+        elif norm_type == "rms_norm":
+            self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
         else:
             raise ValueError(f"unknown norm_type {norm_type}")
-        # linear_2
+
         self.linear_2 = None
         if out_dim is not None:
-            self.linear_2 = nn.Linear(
-                embedding_dim,
-                out_dim,
-                bias=bias,
-            )
+            self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias)

     def forward(
         self,
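What `LuminaLayerNormContinuous` computes, as a hedged standalone sketch: the conditioning embedding is projected to a per-channel scale that modulates the normalized activations (plain `LayerNorm` stands in for the configurable norm here, and the optional `linear_2` output projection is omitted):

```python
import torch
import torch.nn as nn

class AdaLNContinuousSketch(nn.Module):
    def __init__(self, embedding_dim: int, conditioning_embedding_dim: int):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)

    def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
        scale = self.linear_1(self.silu(conditioning))   # (batch, embedding_dim)
        return self.norm(x) * (1 + scale.unsqueeze(1))   # broadcast over sequence axis

out = AdaLNContinuousSketch(32, 48)(torch.randn(2, 16, 32), torch.randn(2, 48))
```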
--- a/src/diffusers/models/transformers/transformer_mochi.py
+++ b/src/diffusers/models/transformers/transformer_mochi.py
@@ -26,7 +26,7 @@ from ..attention_processor import Attention, FluxAttnProcessor2_0
 from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
-from ..normalization import AdaLayerNormContinuous, MochiRMSNormZero, RMSNorm
+from ..normalization import AdaLayerNormContinuous, LuminaLayerNormContinuous, MochiRMSNormZero, RMSNorm


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -55,7 +55,14 @@ class MochiTransformerBlock(nn.Module):
         if not context_pre_only:
             self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim)
         else:
-            self.norm1_context = nn.Linear(dim, pooled_projection_dim)
+            self.norm1_context = LuminaLayerNormContinuous(
+                embedding_dim=pooled_projection_dim,
+                conditioning_embedding_dim=dim,
+                eps=1e-6,
+                elementwise_affine=False,
+                norm_type="rms_norm",
+                out_dim=None,
+            )

         self.attn1 = Attention(
             query_dim=dim,
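The branch matters downstream: non-final blocks get back a modulated context plus gate tensors, while the final (`context_pre_only`) block only normalizes and modulates. A hedged sketch with toy modules (not the real `MochiRMSNormZero` / `LuminaLayerNormContinuous` implementations):

```python
import torch
import torch.nn as nn

class WithGates(nn.Module):
    """Stand-in for MochiRMSNormZero: modulated context plus three gate tensors."""
    def __init__(self, dim: int, cond_dim: int):
        super().__init__()
        self.proj = nn.Linear(cond_dim, 4 * dim)
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)

    def forward(self, x, cond):
        scale, gate_msa, scale_mlp, gate_mlp = self.proj(cond).chunk(4, dim=-1)
        return self.norm(x) * (1 + scale.unsqueeze(1)), gate_msa, scale_mlp, gate_mlp

class NormOnly(nn.Module):
    """Stand-in for the context_pre_only case: a single tensor, no gates."""
    def __init__(self, dim: int, cond_dim: int):
        super().__init__()
        self.proj = nn.Linear(cond_dim, dim)
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)

    def forward(self, x, cond):
        return self.norm(x) * (1 + self.proj(cond).unsqueeze(1))

ctx, temb = torch.randn(2, 16, 32), torch.randn(2, 48)
norm_ctx, gate_msa, scale_mlp, gate_mlp = WithGates(32, 48)(ctx, temb)  # 4-tuple
norm_ctx_only = NormOnly(32, 48)(ctx, temb)                             # single tensor
```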
@@ -83,7 +90,9 @@ class MochiTransformerBlock(nn.Module):
         self.ff = FeedForward(dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False)
         self.ff_context = None
         if not context_pre_only:
-            self.ff_context = FeedForward(pooled_projection_dim, inner_dim=self.ff_context_inner_dim, activation_fn=activation_fn, bias=False)
+            self.ff_context = FeedForward(
+                pooled_projection_dim, inner_dim=self.ff_context_inner_dim, activation_fn=activation_fn, bias=False
+            )

         self.norm4 = RMSNorm(dim, eps=1e-6, elementwise_affine=False)
         self.norm4_context = RMSNorm(pooled_projection_dim, eps=1e-6, elementwise_affine=False)
@@ -102,7 +111,7 @@ class MochiTransformerBlock(nn.Module):
                 encoder_hidden_states, temb
             )
         else:
-            norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)
+            norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)

         attn_hidden_states, context_attn_hidden_states = self.attn1(
             hidden_states=norm_hidden_states,
@@ -112,7 +121,7 @@ class MochiTransformerBlock(nn.Module):

         hidden_states = hidden_states + self.norm2(attn_hidden_states) * torch.tanh(gate_msa).unsqueeze(1)
         norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp.unsqueeze(1))

         if not self.context_pre_only:
             encoder_hidden_states = encoder_hidden_states + self.norm2_context(
                 context_attn_hidden_states
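The update visible in this hunk is the gated residual pattern: the attention output is normalized, scaled by a tanh-squashed gate derived from the timestep embedding, and added back, after which the MLP input is modulated per channel. A hedged standalone sketch with random stand-ins for the module outputs:

```python
import torch

batch, seq, dim = 2, 8, 16
hidden_states = torch.randn(batch, seq, dim)
attn_hidden_states = torch.randn(batch, seq, dim)  # stand-in for self.norm2(attn output)
gate_msa = torch.randn(batch, dim)
scale_mlp = torch.randn(batch, dim)

# Residual branch gated by tanh(gate_msa), broadcast over the sequence axis.
hidden_states = hidden_states + attn_hidden_states * torch.tanh(gate_msa).unsqueeze(1)
# Pre-MLP modulation: per-channel scale on the (normalized) states.
norm_hidden_states = hidden_states * (1 + scale_mlp.unsqueeze(1))
```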
@@ -207,7 +216,9 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin):
         post_patch_height = height // p
         post_patch_width = width // p

-        temb, encoder_hidden_states = self.time_embed(timestep, encoder_hidden_states, encoder_attention_mask, hidden_dtype=hidden_states.dtype)
+        temb, encoder_hidden_states = self.time_embed(
+            timestep, encoder_hidden_states, encoder_attention_mask, hidden_dtype=hidden_states.dtype
+        )

         hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
         hidden_states = self.patch_embed(hidden_states)
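For orientation, the `permute`/`flatten` pair after the `time_embed` call folds the frame axis into the batch so the 2D patch embedding runs once per frame. A hedged shape walk-through with toy sizes:

```python
import torch

batch, channels, frames, height, width = 1, 4, 3, 8, 8
hidden_states = torch.randn(batch, channels, frames, height, width)

# Fold frames into the batch so a 2D patch embed treats each frame as an image.
hidden_states = hidden_states.permute(0, 2, 1, 3, 4)  # (batch, frames, channels, H, W)
hidden_states = hidden_states.flatten(0, 1)           # (batch * frames, channels, H, W)
assert hidden_states.shape == (batch * frames, channels, height, width)
```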