diff --git a/scripts/convert_mochi_to_diffusers.py b/scripts/convert_mochi_to_diffusers.py
index 83f642f65c..1d1d10a6ad 100644
--- a/scripts/convert_mochi_to_diffusers.py
+++ b/scripts/convert_mochi_to_diffusers.py
@@ -4,8 +4,8 @@ from contextlib import nullcontext
 
 import torch
 from accelerate import init_empty_weights
 from safetensors.torch import load_file
-# from transformers import T5EncoderModel, T5Tokenizer
+# from transformers import T5EncoderModel, T5Tokenizer
 
 from diffusers import MochiTransformer3DModel
 from diffusers.utils.import_utils import is_accelerate_available
@@ -72,10 +72,12 @@ def convert_mochi_transformer_checkpoint_to_diffusers(ckpt_path):
                 old_prefix + "mod_y.bias"
             )
         else:
-            new_state_dict[block_prefix + "norm1_context.weight"] = original_state_dict.pop(
+            new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = original_state_dict.pop(
                 old_prefix + "mod_y.weight"
             )
-            new_state_dict[block_prefix + "norm1_context.bias"] = original_state_dict.pop(old_prefix + "mod_y.bias")
+            new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = original_state_dict.pop(
+                old_prefix + "mod_y.bias"
+            )
 
         # Visual attention
         qkv_weight = original_state_dict.pop(old_prefix + "attn.qkv_x.weight")
@@ -158,7 +160,7 @@ def main(args):
         raise ValueError(f"Unsupported dtype: {args.dtype}")
 
     transformer = None
-    vae = None
+    # vae = None
 
     if args.transformer_checkpoint_path is not None:
         converted_transformer_state_dict = convert_mochi_transformer_checkpoint_to_diffusers(
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index cfbc2bc140..ce0f9d87c8 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -1794,7 +1794,9 @@ class FluxAttnProcessor2_0:
             hidden_states = attn.to_out[0](hidden_states)
             # dropout
             hidden_states = attn.to_out[1](hidden_states)
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            if hasattr(attn, "to_add_out"):
+                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
 
             return hidden_states, encoder_hidden_states
         else:
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 3788829f16..3cf808430c 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -1304,7 +1304,12 @@ class LuminaCombinedTimestepCaptionEmbedding(nn.Module):
 
 
 class MochiCombinedTimestepCaptionEmbedding(nn.Module):
     def __init__(
-        self, embedding_dim: int, pooled_projection_dim: int, text_embed_dim: int, time_embed_dim: int = 256, num_attention_heads: int = 8
+        self,
+        embedding_dim: int,
+        pooled_projection_dim: int,
+        text_embed_dim: int,
+        time_embed_dim: int = 256,
+        num_attention_heads: int = 8,
     ) -> None:
         super().__init__()
diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py
index e11faee490..9058320998 100644
--- a/src/diffusers/models/normalization.py
+++ b/src/diffusers/models/normalization.py
@@ -385,20 +385,21 @@ class LuminaLayerNormContinuous(nn.Module):
         out_dim: Optional[int] = None,
     ):
         super().__init__()
+
         # AdaLN
         self.silu = nn.SiLU()
         self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
+
         if norm_type == "layer_norm":
             self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
+        elif norm_type == "rms_norm":
+            self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
         else:
             raise ValueError(f"unknown norm_type {norm_type}")
-        # linear_2
+
+        self.linear_2 = None
         if out_dim is not None:
-            self.linear_2 = nn.Linear(
-                embedding_dim,
-                out_dim,
-                bias=bias,
-            )
+            self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias)
 
     def forward(
         self,
diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py
index 8aa7e48d3f..7ece241e4b 100644
--- a/src/diffusers/models/transformers/transformer_mochi.py
+++ b/src/diffusers/models/transformers/transformer_mochi.py
@@ -26,7 +26,7 @@ from ..attention_processor import Attention, FluxAttnProcessor2_0
 from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
-from ..normalization import AdaLayerNormContinuous, MochiRMSNormZero, RMSNorm
+from ..normalization import AdaLayerNormContinuous, LuminaLayerNormContinuous, MochiRMSNormZero, RMSNorm
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -55,7 +55,14 @@ class MochiTransformerBlock(nn.Module):
         if not context_pre_only:
             self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim)
         else:
-            self.norm1_context = nn.Linear(dim, pooled_projection_dim)
+            self.norm1_context = LuminaLayerNormContinuous(
+                embedding_dim=pooled_projection_dim,
+                conditioning_embedding_dim=dim,
+                eps=1e-6,
+                elementwise_affine=False,
+                norm_type="rms_norm",
+                out_dim=None,
+            )
 
         self.attn1 = Attention(
             query_dim=dim,
@@ -83,7 +90,9 @@ class MochiTransformerBlock(nn.Module):
         self.ff = FeedForward(dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False)
         self.ff_context = None
         if not context_pre_only:
-            self.ff_context = FeedForward(pooled_projection_dim, inner_dim=self.ff_context_inner_dim, activation_fn=activation_fn, bias=False)
+            self.ff_context = FeedForward(
+                pooled_projection_dim, inner_dim=self.ff_context_inner_dim, activation_fn=activation_fn, bias=False
+            )
 
         self.norm4 = RMSNorm(dim, eps=1e-6, elementwise_affine=False)
         self.norm4_context = RMSNorm(pooled_projection_dim, eps=1e-56, elementwise_affine=False)
@@ -102,7 +111,7 @@ class MochiTransformerBlock(nn.Module):
                 encoder_hidden_states, temb
             )
         else:
-            norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)
+            norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
 
         attn_hidden_states, context_attn_hidden_states = self.attn1(
             hidden_states=norm_hidden_states,
@@ -112,7 +121,7 @@ class MochiTransformerBlock(nn.Module):
 
         hidden_states = hidden_states + self.norm2(attn_hidden_states) * torch.tanh(gate_msa).unsqueeze(1)
         norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp.unsqueeze(1))
-        
+
         if not self.context_pre_only:
             encoder_hidden_states = encoder_hidden_states + self.norm2_context(
                 context_attn_hidden_states
@@ -207,7 +216,9 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin):
         post_patch_height = height // p
         post_patch_width = width // p
 
-        temb, encoder_hidden_states = self.time_embed(timestep, encoder_hidden_states, encoder_attention_mask, hidden_dtype=hidden_states.dtype)
+        temb, encoder_hidden_states = self.time_embed(
+            timestep, encoder_hidden_states, encoder_attention_mask, hidden_dtype=hidden_states.dtype
+        )
 
         hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
         hidden_states = self.patch_embed(hidden_states)
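Note on the norm1_context remapping above: the context_pre_only branch of MochiTransformerBlock now builds a LuminaLayerNormContinuous, whose AdaLN projection lives in a linear_1 submodule, instead of a bare nn.Linear, which is why the conversion script writes old_prefix + "mod_y.*" into block_prefix + "norm1_context.linear_1.*". Below is a minimal sketch of the resulting parameter layout, assuming the patched normalization.py (with the rms_norm branch) and illustrative Mochi-like sizes (dim=3072, pooled_projection_dim=1536); the exact dims are not taken from this diff.

    from diffusers.models.normalization import LuminaLayerNormContinuous

    # Same constructor call as the context_pre_only branch of MochiTransformerBlock.
    norm1_context = LuminaLayerNormContinuous(
        embedding_dim=1536,               # pooled_projection_dim (illustrative value)
        conditioning_embedding_dim=3072,  # dim (illustrative value)
        eps=1e-6,
        elementwise_affine=False,
        norm_type="rms_norm",
        out_dim=None,
    )

    # With elementwise_affine=False and out_dim=None, the only parameters are the
    # AdaLN projection, so the converted checkpoint keys line up as:
    #   old_prefix + "mod_y.weight" -> block_prefix + "norm1_context.linear_1.weight"
    #   old_prefix + "mod_y.bias"   -> block_prefix + "norm1_context.linear_1.bias"
    print(sorted(norm1_context.state_dict().keys()))
    # ['linear_1.bias', 'linear_1.weight']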