mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
Aryan
2024-10-24 05:47:03 +02:00
parent 98a4554ac6
commit 85c8734cdc
5 changed files with 39 additions and 18 deletions

View File

@@ -4,8 +4,8 @@ from contextlib import nullcontext
import torch
from accelerate import init_empty_weights
from safetensors.torch import load_file
# from transformers import T5EncoderModel, T5Tokenizer
# from transformers import T5EncoderModel, T5Tokenizer
from diffusers import MochiTransformer3DModel
from diffusers.utils.import_utils import is_accelerate_available
@@ -72,10 +72,12 @@ def convert_mochi_transformer_checkpoint_to_diffusers(ckpt_path):
old_prefix + "mod_y.bias"
)
else:
new_state_dict[block_prefix + "norm1_context.weight"] = original_state_dict.pop(
new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = original_state_dict.pop(
old_prefix + "mod_y.weight"
)
new_state_dict[block_prefix + "norm1_context.bias"] = original_state_dict.pop(old_prefix + "mod_y.bias")
new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = original_state_dict.pop(
old_prefix + "mod_y.bias"
)
# Visual attention
qkv_weight = original_state_dict.pop(old_prefix + "attn.qkv_x.weight")
@@ -158,7 +160,7 @@ def main(args):
raise ValueError(f"Unsupported dtype: {args.dtype}")
transformer = None
vae = None
# vae = None
if args.transformer_checkpoint_path is not None:
converted_transformer_state_dict = convert_mochi_transformer_checkpoint_to_diffusers(
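
The hunk above follows the module change in the Mochi transformer (last file below): norm1_context is now a LuminaLayerNormContinuous, so its AdaLN projection lives under linear_1.* and the mod_y.* checkpoint keys have to land there. A minimal sketch of just that remapping, using a hypothetical helper name and plain dicts of tensors:

    def remap_context_norm_keys(original_state_dict, new_state_dict, old_prefix, block_prefix):
        # The context AdaLN projection now sits on norm1_context.linear_1,
        # so mod_y.* maps onto linear_1.* instead of the bare module.
        new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = original_state_dict.pop(
            old_prefix + "mod_y.weight"
        )
        new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = original_state_dict.pop(
            old_prefix + "mod_y.bias"
        )
        return new_state_dict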

View File

@@ -1794,7 +1794,9 @@ class FluxAttnProcessor2_0:
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
if hasattr(attn, "to_add_out"):
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
else:
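
The hasattr guard matters because the Mochi transformer (last file below) reuses FluxAttnProcessor2_0, and an Attention module built for a context_pre_only block may never define to_add_out. A minimal sketch of the pattern, with a hypothetical stub standing in for a full Attention module:

    import torch
    import torch.nn as nn

    class AttnStub(nn.Module):
        # Hypothetical stand-in: to_add_out only exists when the block still
        # produces a context output stream.
        def __init__(self, dim: int, context_pre_only: bool):
            super().__init__()
            if not context_pre_only:
                self.to_add_out = nn.Linear(dim, dim)

    attn = AttnStub(dim=16, context_pre_only=True)
    encoder_hidden_states = torch.randn(1, 4, 16)
    if hasattr(attn, "to_add_out"):  # skipped for context_pre_only blocks
        encoder_hidden_states = attn.to_add_out(encoder_hidden_states)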

View File

@@ -1304,7 +1304,12 @@ class LuminaCombinedTimestepCaptionEmbedding(nn.Module):
class MochiCombinedTimestepCaptionEmbedding(nn.Module):
def __init__(
self, embedding_dim: int, pooled_projection_dim: int, text_embed_dim: int, time_embed_dim: int = 256, num_attention_heads: int = 8
self,
embedding_dim: int,
pooled_projection_dim: int,
text_embed_dim: int,
time_embed_dim: int = 256,
num_attention_heads: int = 8,
) -> None:
super().__init__()
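
Only the constructor signature is reflowed in this hunk; behaviour is unchanged. For reference, a hedged instantiation sketch (the dimensions are illustrative, not taken from the diff):

    from diffusers.models.embeddings import MochiCombinedTimestepCaptionEmbedding

    time_embed = MochiCombinedTimestepCaptionEmbedding(
        embedding_dim=3072,          # transformer inner dim (illustrative)
        pooled_projection_dim=1536,  # caption/context stream dim (illustrative)
        text_embed_dim=4096,         # e.g. a T5 hidden size (illustrative)
        time_embed_dim=256,
        num_attention_heads=8,
    )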

View File

@@ -385,20 +385,21 @@ class LuminaLayerNormContinuous(nn.Module):
out_dim: Optional[int] = None,
):
super().__init__()
# AdaLN
self.silu = nn.SiLU()
self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
if norm_type == "layer_norm":
self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
if norm_type == "rms_norm":
self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
else:
raise ValueError(f"unknown norm_type {norm_type}")
# linear_2
self.linear_2 = None
if out_dim is not None:
self.linear_2 = nn.Linear(
embedding_dim,
out_dim,
bias=bias,
)
self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias)
def forward(
self,
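
With norm_type="rms_norm" now handled (unknown values still raise ValueError), this module can act as the RMS-normed AdaLN that the Mochi block below plugs in as norm1_context. A hedged usage sketch: the shapes are illustrative, and the (x, conditioning) call signature follows the transformer diff below:

    import torch
    from diffusers.models.normalization import LuminaLayerNormContinuous

    norm = LuminaLayerNormContinuous(
        embedding_dim=1536,
        conditioning_embedding_dim=3072,
        eps=1e-6,
        elementwise_affine=False,
        norm_type="rms_norm",
        out_dim=None,
    )
    x = torch.randn(2, 256, 1536)  # e.g. caption tokens (illustrative shape)
    cond = torch.randn(2, 3072)    # e.g. pooled timestep/caption embedding
    out = norm(x, cond)            # roughly RMSNorm(x) * (1 + linear_1(silu(cond)))
    assert out.shape == x.shape    # out_dim=None, so linear_2 is skipped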

View File

@@ -26,7 +26,7 @@ from ..attention_processor import Attention, FluxAttnProcessor2_0
from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, MochiRMSNormZero, RMSNorm
from ..normalization import AdaLayerNormContinuous, LuminaLayerNormContinuous, MochiRMSNormZero, RMSNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -55,7 +55,14 @@ class MochiTransformerBlock(nn.Module):
if not context_pre_only:
self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim)
else:
self.norm1_context = nn.Linear(dim, pooled_projection_dim)
self.norm1_context = LuminaLayerNormContinuous(
embedding_dim=pooled_projection_dim,
conditioning_embedding_dim=dim,
eps=1e-6,
elementwise_affine=False,
norm_type="rms_norm",
out_dim=None,
)
self.attn1 = Attention(
query_dim=dim,
@@ -83,7 +90,9 @@ class MochiTransformerBlock(nn.Module):
self.ff = FeedForward(dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False)
self.ff_context = None
if not context_pre_only:
self.ff_context = FeedForward(pooled_projection_dim, inner_dim=self.ff_context_inner_dim, activation_fn=activation_fn, bias=False)
self.ff_context = FeedForward(
pooled_projection_dim, inner_dim=self.ff_context_inner_dim, activation_fn=activation_fn, bias=False
)
self.norm4 = RMSNorm(dim, eps=1e-6, elementwise_affine=False)
self.norm4_context = RMSNorm(pooled_projection_dim, eps=1e-6, elementwise_affine=False)
@@ -102,7 +111,7 @@ class MochiTransformerBlock(nn.Module):
encoder_hidden_states, temb
)
else:
norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)
norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
attn_hidden_states, context_attn_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
@@ -112,7 +121,7 @@ class MochiTransformerBlock(nn.Module):
hidden_states = hidden_states + self.norm2(attn_hidden_states) * torch.tanh(gate_msa).unsqueeze(1)
norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp.unsqueeze(1))
if not self.context_pre_only:
encoder_hidden_states = encoder_hidden_states + self.norm2_context(
context_attn_hidden_states
@@ -207,7 +216,9 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin):
post_patch_height = height // p
post_patch_width = width // p
temb, encoder_hidden_states = self.time_embed(timestep, encoder_hidden_states, encoder_attention_mask, hidden_dtype=hidden_states.dtype)
temb, encoder_hidden_states = self.time_embed(
timestep, encoder_hidden_states, encoder_attention_mask, hidden_dtype=hidden_states.dtype
)
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
hidden_states = self.patch_embed(hidden_states)
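
The switch from nn.Linear to LuminaLayerNormContinuous in the context_pre_only branch is also why the conversion script above now writes norm1_context.linear_1.* keys. A small sketch checking that correspondence (the dim / pooled_projection_dim values are illustrative):

    from diffusers.models.normalization import LuminaLayerNormContinuous

    dim, pooled_projection_dim = 3072, 1536  # illustrative values
    norm1_context = LuminaLayerNormContinuous(
        embedding_dim=pooled_projection_dim,
        conditioning_embedding_dim=dim,
        eps=1e-6,
        elementwise_affine=False,
        norm_type="rms_norm",
        out_dim=None,
    )
    # With elementwise_affine=False the RMSNorm has no parameters, so the only
    # entries are the AdaLN projection targeted by the converter.
    print(sorted(norm1_context.state_dict().keys()))
    # expected: ['linear_1.bias', 'linear_1.weight']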