mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
docs
This commit is contained in:
@@ -302,6 +302,8 @@
|
||||
title: AutoencoderKL
|
||||
- local: api/models/autoencoderkl_cogvideox
|
||||
title: AutoencoderKLCogVideoX
|
||||
- local: api/models/autoencoderkl_mochi
|
||||
title: AutoencoderKLMochi
|
||||
- local: api/models/asymmetricautoencoderkl
|
||||
title: AsymmetricAutoencoderKL
|
||||
- local: api/models/consistency_decoder_vae
|
||||
@@ -394,6 +396,8 @@
|
||||
title: Lumina-T2X
|
||||
- local: api/pipelines/marigold
|
||||
title: Marigold
|
||||
- local: api/pipelines/mochi
|
||||
title: Mochi
|
||||
- local: api/pipelines/panorama
|
||||
title: MultiDiffusion
|
||||
- local: api/pipelines/musicldm
|
||||
|
||||
33
docs/source/en/api/models/autoencoderkl_mochi.md
Normal file
33
docs/source/en/api/models/autoencoderkl_mochi.md
Normal file
@@ -0,0 +1,33 @@
|
||||
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# AutoencoderKLMochi
|
||||
|
||||
The 3D variational autoencoder (VAE) model with KL loss used in [Mochi](https://github.com/genmoai/models) was introduced in [Mochi 1 Preview](https://huggingface.co/genmo/mochi-1-preview) by Tsinghua University & ZhipuAI.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import AutoencoderKLMochi
|
||||
|
||||
vae = AutoencoderKLMochi.from_pretrained("genmo/mochi-1-preview", subfolder="vae", torch_dtype=torch.float32).to("cuda")
|
||||
```
|
||||
|
||||
## AutoencoderKLMochi
|
||||
|
||||
[[autodoc]] AutoencoderKLMochi
|
||||
- decode
|
||||
- encode
|
||||
- all
|
||||
|
||||
## DecoderOutput
|
||||
|
||||
[[autodoc]] models.autoencoders.vae.DecoderOutput
|
||||
36
docs/source/en/api/pipelines/mochi.md
Normal file
36
docs/source/en/api/pipelines/mochi.md
Normal file
@@ -0,0 +1,36 @@
|
||||
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
-->
|
||||
|
||||
# Mochi
|
||||
|
||||
[Mochi 1 Preview](https://huggingface.co/genmo/mochi-1-preview) from Genmo.
|
||||
|
||||
*Mochi 1 preview is an open state-of-the-art video generation model with high-fidelity motion and strong prompt adherence in preliminary evaluation. This model dramatically closes the gap between closed and open video generation systems. The model is released under a permissive Apache 2.0 license.*
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
## MochiPipeline
|
||||
|
||||
[[autodoc]] MochiPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## MochiPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.mochi.pipeline_output.MochiPipelineOutput
|
||||
@@ -34,6 +34,26 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class MochiTransformerBlock(nn.Module):
|
||||
r"""
|
||||
Transformer block used in [Mochi](https://huggingface.co/genmo/mochi-1-preview).
|
||||
|
||||
Args:
|
||||
dim (`int`):
|
||||
The number of channels in the input and output.
|
||||
num_attention_heads (`int`):
|
||||
The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`):
|
||||
The number of channels in each head.
|
||||
qk_norm (`str`, defaults to `"rms_norm"`):
|
||||
The normalization layer to use.
|
||||
activation_fn (`str`, defaults to `"swiglu"`):
|
||||
Activation function to use in feed-forward.
|
||||
context_pre_only (`bool`, defaults to `False`):
|
||||
Whether or not to process context-related conditions with additional layers.
|
||||
eps (`float`, defaults to `1e-6`):
|
||||
Epsilon value for normalization layers.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
@@ -42,7 +62,7 @@ class MochiTransformerBlock(nn.Module):
|
||||
pooled_projection_dim: int,
|
||||
qk_norm: str = "rms_norm",
|
||||
activation_fn: str = "swiglu",
|
||||
context_pre_only: bool = True,
|
||||
context_pre_only: bool = False,
|
||||
eps: float = 1e-6,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -82,6 +102,7 @@ class MochiTransformerBlock(nn.Module):
|
||||
elementwise_affine=True,
|
||||
)
|
||||
|
||||
# TODO(aryan): norm_context layers are not needed when `context_pre_only` is True
|
||||
self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=False)
|
||||
self.norm2_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
|
||||
|
||||
@@ -145,7 +166,17 @@ class MochiTransformerBlock(nn.Module):
|
||||
|
||||
|
||||
class MochiRoPE(nn.Module):
|
||||
def __init__(self, base_height: int = 192, base_width: int = 192, theta: float = 10000.0) -> None:
|
||||
r"""
|
||||
RoPE implementation used in [Mochi](https://huggingface.co/genmo/mochi-1-preview).
|
||||
|
||||
Args:
|
||||
base_height (`int`, defaults to `192`):
|
||||
Base height used to compute interpolation scale for rotary positional embeddings.
|
||||
base_width (`int`, defaults to `192`):
|
||||
Base width used to compute interpolation scale for rotary positional embeddings.
|
||||
"""
|
||||
|
||||
def __init__(self, base_height: int = 192, base_width: int = 192) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.target_area = base_height * base_width
|
||||
@@ -195,6 +226,34 @@ class MochiRoPE(nn.Module):
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class MochiTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
A Transformer model for video-like data introduced in [Mochi](https://huggingface.co/genmo/mochi-1-preview).
|
||||
|
||||
Args:
|
||||
patch_size (`int`, defaults to `2`):
|
||||
The size of the patches to use in the patch embedding layer.
|
||||
num_attention_heads (`int`, defaults to `24`):
|
||||
The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, defaults to `128`):
|
||||
The number of channels in each head.
|
||||
num_layers (`int`, defaults to `48`):
|
||||
The number of layers of Transformer blocks to use.
|
||||
in_channels (`int`, defaults to `12`):
|
||||
The number of channels in the input.
|
||||
out_channels (`int`, *optional*, defaults to `None`):
|
||||
The number of channels in the output.
|
||||
qk_norm (`str`, defaults to `"rms_norm"`):
|
||||
The normalization layer to use.
|
||||
text_embed_dim (`int`, defaults to `4096`):
|
||||
Input dimension of text embeddings from the text encoder.
|
||||
time_embed_dim (`int`, defaults to `256`):
|
||||
Output dimension of timestep embeddings.
|
||||
activation_fn (`str`, defaults to `"swiglu"`):
|
||||
Activation function to use in feed-forward.
|
||||
max_sequence_length (`int`, defaults to `256`):
|
||||
The maximum sequence length of text embeddings supported.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
|
||||
Reference in New Issue
Block a user