From 5f43c6a41f607a45671ebdbc10e255e4b1bb4351 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Fri, 25 Oct 2024 10:27:28 +0200
Subject: [PATCH] docs

---
 docs/source/en/_toctree.yml                   |  4 ++
 .../en/api/models/autoencoderkl_mochi.md      | 33 ++++++++++
 docs/source/en/api/pipelines/mochi.md         | 36 +++++++++++
 .../models/transformers/transformer_mochi.py  | 63 ++++++++++++++++++-
 4 files changed, 134 insertions(+), 2 deletions(-)
 create mode 100644 docs/source/en/api/models/autoencoderkl_mochi.md
 create mode 100644 docs/source/en/api/pipelines/mochi.md

diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 97d4630d08..6dc89c3eac 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -302,6 +302,8 @@
         title: AutoencoderKL
       - local: api/models/autoencoderkl_cogvideox
         title: AutoencoderKLCogVideoX
+      - local: api/models/autoencoderkl_mochi
+        title: AutoencoderKLMochi
       - local: api/models/asymmetricautoencoderkl
         title: AsymmetricAutoencoderKL
       - local: api/models/consistency_decoder_vae
@@ -394,6 +396,8 @@
         title: Lumina-T2X
       - local: api/pipelines/marigold
         title: Marigold
+      - local: api/pipelines/mochi
+        title: Mochi
      - local: api/pipelines/panorama
        title: MultiDiffusion
      - local: api/pipelines/musicldm
diff --git a/docs/source/en/api/models/autoencoderkl_mochi.md b/docs/source/en/api/models/autoencoderkl_mochi.md
new file mode 100644
index 0000000000..7498182f8b
--- /dev/null
+++ b/docs/source/en/api/models/autoencoderkl_mochi.md
@@ -0,0 +1,33 @@
+
+# AutoencoderKLMochi
+
+The 3D variational autoencoder (VAE) model with KL loss used in [Mochi](https://github.com/genmoai/models) was introduced in [Mochi 1 Preview](https://huggingface.co/genmo/mochi-1-preview) by Genmo.
+
+The model can be loaded with the following code snippet.
+
+```python
+import torch
+
+from diffusers import AutoencoderKLMochi
+
+vae = AutoencoderKLMochi.from_pretrained("genmo/mochi-1-preview", subfolder="vae", torch_dtype=torch.float32).to("cuda")
+```
+
+## AutoencoderKLMochi
+
+[[autodoc]] AutoencoderKLMochi
+  - decode
+  - encode
+  - all
+
+## DecoderOutput
+
+[[autodoc]] models.autoencoders.vae.DecoderOutput
diff --git a/docs/source/en/api/pipelines/mochi.md b/docs/source/en/api/pipelines/mochi.md
new file mode 100644
index 0000000000..f29297e590
--- /dev/null
+++ b/docs/source/en/api/pipelines/mochi.md
@@ -0,0 +1,36 @@
+
+# Mochi
+
+[Mochi 1 Preview](https://huggingface.co/genmo/mochi-1-preview) is a text-to-video generation model from Genmo.
+
+*Mochi 1 preview is an open state-of-the-art video generation model with high-fidelity motion and strong prompt adherence in preliminary evaluation. This model dramatically closes the gap between closed and open video generation systems. The model is released under a permissive Apache 2.0 license.*
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
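+
+The example below is a minimal text-to-video sketch: it loads the pipeline in `bfloat16`, generates a short clip, and exports it with `export_to_video`. The prompt, frame count, and memory-saving calls are illustrative assumptions rather than tuned settings.
+
+```python
+import torch
+from diffusers import MochiPipeline
+from diffusers.utils import export_to_video
+
+pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16)
+
+# Reduce peak memory usage by offloading submodules to the CPU and decoding latents in tiles.
+pipe.enable_model_cpu_offload()
+pipe.enable_vae_tiling()
+
+prompt = "A close-up of a chameleon resting on a mossy branch, shallow depth of field."
+frames = pipe(prompt, num_frames=84).frames[0]
+export_to_video(frames, "mochi.mp4", fps=30)
+```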
+
+## MochiPipeline
+
+[[autodoc]] MochiPipeline
+  - all
+  - __call__
+
+## MochiPipelineOutput
+
+[[autodoc]] pipelines.mochi.pipeline_output.MochiPipelineOutput
diff --git a/src/diffusers/models/transformers/transformer_mochi.py b/src/diffusers/models/transformers/transformer_mochi.py
index 660da3bc14..0eb15ffdfd 100644
--- a/src/diffusers/models/transformers/transformer_mochi.py
+++ b/src/diffusers/models/transformers/transformer_mochi.py
@@ -34,6 +34,26 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 @maybe_allow_in_graph
 class MochiTransformerBlock(nn.Module):
+    r"""
+    Transformer block used in [Mochi](https://huggingface.co/genmo/mochi-1-preview).
+
+    Args:
+        dim (`int`):
+            The number of channels in the input and output.
+        num_attention_heads (`int`):
+            The number of heads to use for multi-head attention.
+        attention_head_dim (`int`):
+            The number of channels in each head.
+        qk_norm (`str`, defaults to `"rms_norm"`):
+            The normalization layer to use.
+        activation_fn (`str`, defaults to `"swiglu"`):
+            Activation function to use in feed-forward.
+        context_pre_only (`bool`, defaults to `False`):
+            Whether or not to process context-related conditions with additional layers.
+        eps (`float`, defaults to `1e-6`):
+            Epsilon value for normalization layers.
+    """
+
     def __init__(
         self,
         dim: int,
@@ -42,7 +62,7 @@ class MochiTransformerBlock(nn.Module):
         pooled_projection_dim: int,
         qk_norm: str = "rms_norm",
         activation_fn: str = "swiglu",
-        context_pre_only: bool = True,
+        context_pre_only: bool = False,
         eps: float = 1e-6,
     ) -> None:
         super().__init__()
@@ -82,6 +102,7 @@ class MochiTransformerBlock(nn.Module):
             elementwise_affine=True,
         )
 
+        # TODO(aryan): norm_context layers are not needed when `context_pre_only` is True
         self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=False)
         self.norm2_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
 
@@ -145,7 +166,17 @@
 
 
 class MochiRoPE(nn.Module):
-    def __init__(self, base_height: int = 192, base_width: int = 192, theta: float = 10000.0) -> None:
+    r"""
+    RoPE implementation used in [Mochi](https://huggingface.co/genmo/mochi-1-preview).
+
+    Args:
+        base_height (`int`, defaults to `192`):
+            Base height used to compute interpolation scale for rotary positional embeddings.
+        base_width (`int`, defaults to `192`):
+            Base width used to compute interpolation scale for rotary positional embeddings.
+    """
+
+    def __init__(self, base_height: int = 192, base_width: int = 192) -> None:
         super().__init__()
 
         self.target_area = base_height * base_width
@@ -195,6 +226,34 @@
 
 @maybe_allow_in_graph
 class MochiTransformer3DModel(ModelMixin, ConfigMixin):
+    r"""
+    A Transformer model for video-like data introduced in [Mochi](https://huggingface.co/genmo/mochi-1-preview).
+
+    Args:
+        patch_size (`int`, defaults to `2`):
+            The size of the patches to use in the patch embedding layer.
+        num_attention_heads (`int`, defaults to `24`):
+            The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, defaults to `128`):
+            The number of channels in each head.
+        num_layers (`int`, defaults to `48`):
+            The number of layers of Transformer blocks to use.
+        in_channels (`int`, defaults to `12`):
+            The number of channels in the input.
+        out_channels (`int`, *optional*, defaults to `None`):
+            The number of channels in the output.
+        qk_norm (`str`, defaults to `"rms_norm"`):
+            The normalization layer to use.
+        text_embed_dim (`int`, defaults to `4096`):
+            Input dimension of text embeddings from the text encoder.
+        time_embed_dim (`int`, defaults to `256`):
+            Output dimension of timestep embeddings.
+        activation_fn (`str`, defaults to `"swiglu"`):
+            Activation function to use in feed-forward.
+        max_sequence_length (`int`, defaults to `256`):
+            The maximum sequence length of text embeddings supported.
+    """
+
     _supports_gradient_checkpointing = True
 
     @register_to_config
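For reference, a minimal sketch of loading the transformer documented above on its own; the `subfolder` name and dtype are assumptions based on the usual Diffusers checkpoint layout rather than values taken from this patch.

```python
import torch
from diffusers import MochiTransformer3DModel

# Load only the denoising transformer from the Mochi checkpoint (assumed subfolder layout).
transformer = MochiTransformer3DModel.from_pretrained(
    "genmo/mochi-1-preview", subfolder="transformer", torch_dtype=torch.bfloat16
)

# The config mirrors the documented defaults, e.g. 48 layers and 24 attention heads.
print(transformer.config.num_layers, transformer.config.num_attention_heads)
```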