
Rename LTX blocks and docs title (#10213)

* rename blocks and docs

* fix docs

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Author: Aryan
Date: 2024-12-23 15:29:10 +05:30 (committed by GitHub)
parent 055d95543a
commit 9d27df8071
5 changed files with 49 additions and 48 deletions

docs/source/en/_toc.yml

@@ -429,7 +429,7 @@
- local: api/pipelines/ledits_pp
title: LEDITS++
- local: api/pipelines/ltx_video
-title: LTX
+title: LTXVideo
- local: api/pipelines/lumina
title: Lumina-T2X
- local: api/pipelines/marigold

docs/source/en/api/models/autoencoderkl_ltx_video.md

@@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
```python
import torch
from diffusers import AutoencoderKLLTXVideo
vae = AutoencoderKLLTXVideo.from_pretrained("TODO/TODO", subfolder="vae", torch_dtype=torch.float32).to("cuda")
vae = AutoencoderKLLTXVideo.from_pretrained("Lightricks/LTX-Video", subfolder="vae", torch_dtype=torch.float32).to("cuda")
```
## AutoencoderKLLTXVideo
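
As a quick sanity check of the repo-id fix above, here is a minimal sketch (not part of this commit) of passing the standalone VAE into the text-to-video pipeline; `LTXPipeline` and `export_to_video` are the public diffusers APIs, and the prompt and frame count are illustrative.

```python
import torch

from diffusers import AutoencoderKLLTXVideo, LTXPipeline
from diffusers.utils import export_to_video

# Load the VAE in bfloat16 here so its dtype matches the rest of the pipeline.
vae = AutoencoderKLLTXVideo.from_pretrained(
    "Lightricks/LTX-Video", subfolder="vae", torch_dtype=torch.bfloat16
)
pipe = LTXPipeline.from_pretrained(
    "Lightricks/LTX-Video", vae=vae, torch_dtype=torch.bfloat16
).to("cuda")

frames = pipe(prompt="A sailboat drifting through morning fog", num_frames=65).frames[0]
export_to_video(frames, "output.mp4", fps=24)
```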

docs/source/en/api/models/ltx_video_transformer3d.md

@@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
```python
import torch
from diffusers import LTXVideoTransformer3DModel
transformer = LTXVideoTransformer3DModel.from_pretrained("TODO/TODO", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
transformer = LTXVideoTransformer3DModel.from_pretrained("Lightricks/LTX-Video", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
```
## LTXVideoTransformer3DModel
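
Similarly, a hedged sketch (same assumptions as above, not part of this commit) of dropping the separately loaded transformer into the pipeline, with model offloading as an optional memory saver:

```python
import torch

from diffusers import LTXPipeline, LTXVideoTransformer3DModel

transformer = LTXVideoTransformer3DModel.from_pretrained(
    "Lightricks/LTX-Video", subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = LTXPipeline.from_pretrained(
    "Lightricks/LTX-Video", transformer=transformer, torch_dtype=torch.bfloat16
)
# Standard DiffusionPipeline toggle; moves each module to GPU only while it runs.
pipe.enable_model_cpu_offload()

frames = pipe(prompt="A train crossing a snowy mountain pass", num_frames=65).frames[0]
```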

src/diffusers/models/autoencoders/autoencoder_kl_ltx.py

@@ -28,7 +28,7 @@ from ..normalization import RMSNorm
from .vae import DecoderOutput, DiagonalGaussianDistribution
-class LTXCausalConv3d(nn.Module):
+class LTXVideoCausalConv3d(nn.Module):
def __init__(
self,
in_channels: int,
@@ -79,9 +79,9 @@ class LTXCausalConv3d(nn.Module):
return hidden_states
-class LTXResnetBlock3d(nn.Module):
+class LTXVideoResnetBlock3d(nn.Module):
r"""
-A 3D ResNet block used in the LTX model.
+A 3D ResNet block used in the LTXVideo model.
Args:
in_channels (`int`):
@@ -117,13 +117,13 @@ class LTXResnetBlock3d(nn.Module):
self.nonlinearity = get_activation(non_linearity)
self.norm1 = RMSNorm(in_channels, eps=1e-8, elementwise_affine=elementwise_affine)
-self.conv1 = LTXCausalConv3d(
+self.conv1 = LTXVideoCausalConv3d(
in_channels=in_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal
)
self.norm2 = RMSNorm(out_channels, eps=1e-8, elementwise_affine=elementwise_affine)
self.dropout = nn.Dropout(dropout)
-self.conv2 = LTXCausalConv3d(
+self.conv2 = LTXVideoCausalConv3d(
in_channels=out_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal
)
@@ -131,7 +131,7 @@ class LTXResnetBlock3d(nn.Module):
self.conv_shortcut = None
if in_channels != out_channels:
self.norm3 = nn.LayerNorm(in_channels, eps=eps, elementwise_affine=True, bias=True)
-self.conv_shortcut = LTXCausalConv3d(
+self.conv_shortcut = LTXVideoCausalConv3d(
in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, is_causal=is_causal
)
@@ -157,7 +157,7 @@ class LTXResnetBlock3d(nn.Module):
return hidden_states
-class LTXUpsampler3d(nn.Module):
+class LTXVideoUpsampler3d(nn.Module):
def __init__(
self,
in_channels: int,
@@ -170,7 +170,7 @@ class LTXUpsampler3d(nn.Module):
out_channels = in_channels * stride[0] * stride[1] * stride[2]
-self.conv = LTXCausalConv3d(
+self.conv = LTXVideoCausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
@@ -191,9 +191,9 @@ class LTXUpsampler3d(nn.Module):
return hidden_states
-class LTXDownBlock3D(nn.Module):
+class LTXVideoDownBlock3D(nn.Module):
r"""
-Down block used in the LTX model.
+Down block used in the LTXVideo model.
Args:
in_channels (`int`):
@@ -235,7 +235,7 @@ class LTXDownBlock3D(nn.Module):
resnets = []
for _ in range(num_layers):
resnets.append(
-LTXResnetBlock3d(
+LTXVideoResnetBlock3d(
in_channels=in_channels,
out_channels=in_channels,
dropout=dropout,
@@ -250,7 +250,7 @@ class LTXDownBlock3D(nn.Module):
if spatio_temporal_scale:
self.downsamplers = nn.ModuleList(
[
-LTXCausalConv3d(
+LTXVideoCausalConv3d(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=3,
@@ -262,7 +262,7 @@ class LTXDownBlock3D(nn.Module):
self.conv_out = None
if in_channels != out_channels:
-self.conv_out = LTXResnetBlock3d(
+self.conv_out = LTXVideoResnetBlock3d(
in_channels=in_channels,
out_channels=out_channels,
dropout=dropout,
@@ -300,9 +300,9 @@ class LTXDownBlock3D(nn.Module):
# Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d
-class LTXMidBlock3d(nn.Module):
+class LTXVideoMidBlock3d(nn.Module):
r"""
-A middle block used in the LTX model.
+A middle block used in the LTXVideo model.
Args:
in_channels (`int`):
@@ -335,7 +335,7 @@ class LTXMidBlock3d(nn.Module):
resnets = []
for _ in range(num_layers):
resnets.append(
-LTXResnetBlock3d(
+LTXVideoResnetBlock3d(
in_channels=in_channels,
out_channels=in_channels,
dropout=dropout,
@@ -367,9 +367,9 @@ class LTXMidBlock3d(nn.Module):
return hidden_states
-class LTXUpBlock3d(nn.Module):
+class LTXVideoUpBlock3d(nn.Module):
r"""
-Up block used in the LTX model.
+Up block used in the LTXVideo model.
Args:
in_channels (`int`):
@@ -410,7 +410,7 @@ class LTXUpBlock3d(nn.Module):
self.conv_in = None
if in_channels != out_channels:
-self.conv_in = LTXResnetBlock3d(
+self.conv_in = LTXVideoResnetBlock3d(
in_channels=in_channels,
out_channels=out_channels,
dropout=dropout,
@@ -421,12 +421,12 @@ class LTXUpBlock3d(nn.Module):
self.upsamplers = None
if spatio_temporal_scale:
-self.upsamplers = nn.ModuleList([LTXUpsampler3d(out_channels, stride=(2, 2, 2), is_causal=is_causal)])
+self.upsamplers = nn.ModuleList([LTXVideoUpsampler3d(out_channels, stride=(2, 2, 2), is_causal=is_causal)])
resnets = []
for _ in range(num_layers):
resnets.append(
-LTXResnetBlock3d(
+LTXVideoResnetBlock3d(
in_channels=out_channels,
out_channels=out_channels,
dropout=dropout,
@@ -463,9 +463,9 @@ class LTXUpBlock3d(nn.Module):
return hidden_states
-class LTXEncoder3d(nn.Module):
+class LTXVideoEncoder3d(nn.Module):
r"""
-The `LTXEncoder3D` layer of a variational autoencoder that encodes input video samples to its latent
+The `LTXVideoEncoder3d` layer of a variational autoencoder that encodes input video samples to its latent
representation.
Args:
@@ -509,7 +509,7 @@ class LTXEncoder3d(nn.Module):
output_channel = block_out_channels[0]
-self.conv_in = LTXCausalConv3d(
+self.conv_in = LTXVideoCausalConv3d(
in_channels=self.in_channels,
out_channels=output_channel,
kernel_size=3,
@@ -524,7 +524,7 @@ class LTXEncoder3d(nn.Module):
input_channel = output_channel
output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i]
-down_block = LTXDownBlock3D(
+down_block = LTXVideoDownBlock3D(
in_channels=input_channel,
out_channels=output_channel,
num_layers=layers_per_block[i],
@@ -536,7 +536,7 @@ class LTXEncoder3d(nn.Module):
self.down_blocks.append(down_block)
# mid block
-self.mid_block = LTXMidBlock3d(
+self.mid_block = LTXVideoMidBlock3d(
in_channels=output_channel,
num_layers=layers_per_block[-1],
resnet_eps=resnet_norm_eps,
@@ -546,14 +546,14 @@ class LTXEncoder3d(nn.Module):
# out
self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False)
self.conv_act = nn.SiLU()
-self.conv_out = LTXCausalConv3d(
+self.conv_out = LTXVideoCausalConv3d(
in_channels=output_channel, out_channels=out_channels + 1, kernel_size=3, stride=1, is_causal=is_causal
)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
r"""The forward method of the `LTXEncoder3D` class."""
r"""The forward method of the `LTXVideoEncoder3d` class."""
p = self.patch_size
p_t = self.patch_size_t
@@ -599,9 +599,10 @@ class LTXEncoder3d(nn.Module):
return hidden_states
-class LTXDecoder3d(nn.Module):
+class LTXVideoDecoder3d(nn.Module):
r"""
-The `LTXDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output sample.
+The `LTXVideoDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output
+sample.
Args:
in_channels (`int`, defaults to 128):
@@ -647,11 +648,11 @@ class LTXDecoder3d(nn.Module):
layers_per_block = tuple(reversed(layers_per_block))
output_channel = block_out_channels[0]
-self.conv_in = LTXCausalConv3d(
+self.conv_in = LTXVideoCausalConv3d(
in_channels=in_channels, out_channels=output_channel, kernel_size=3, stride=1, is_causal=is_causal
)
-self.mid_block = LTXMidBlock3d(
+self.mid_block = LTXVideoMidBlock3d(
in_channels=output_channel, num_layers=layers_per_block[0], resnet_eps=resnet_norm_eps, is_causal=is_causal
)
@@ -662,7 +663,7 @@ class LTXDecoder3d(nn.Module):
input_channel = output_channel
output_channel = block_out_channels[i]
-up_block = LTXUpBlock3d(
+up_block = LTXVideoUpBlock3d(
in_channels=input_channel,
out_channels=output_channel,
num_layers=layers_per_block[i + 1],
@@ -676,7 +677,7 @@ class LTXDecoder3d(nn.Module):
# out
self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False)
self.conv_act = nn.SiLU()
-self.conv_out = LTXCausalConv3d(
+self.conv_out = LTXVideoCausalConv3d(
in_channels=output_channel, out_channels=self.out_channels, kernel_size=3, stride=1, is_causal=is_causal
)
@@ -777,7 +778,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
) -> None:
super().__init__()
-self.encoder = LTXEncoder3d(
+self.encoder = LTXVideoEncoder3d(
in_channels=in_channels,
out_channels=latent_channels,
block_out_channels=block_out_channels,
@@ -788,7 +789,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
resnet_norm_eps=resnet_norm_eps,
is_causal=encoder_causal,
)
-self.decoder = LTXDecoder3d(
+self.decoder = LTXVideoDecoder3d(
in_channels=latent_channels,
out_channels=out_channels,
block_out_channels=block_out_channels,
@@ -837,7 +838,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.tile_sample_stride_width = 448
def _set_gradient_checkpointing(self, module, value=False):
-if isinstance(module, (LTXEncoder3d, LTXDecoder3d)):
+if isinstance(module, (LTXVideoEncoder3d, LTXVideoDecoder3d)):
module.gradient_checkpointing = value
def enable_tiling(
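
For reference, a minimal sketch of the public toggles that route through the hooks shown in this hunk; `enable_gradient_checkpointing` comes from `ModelMixin` and `enable_tiling` is defined just above, so only the settings are illustrative.

```python
import torch

from diffusers import AutoencoderKLLTXVideo

vae = AutoencoderKLLTXVideo.from_pretrained(
    "Lightricks/LTX-Video", subfolder="vae", torch_dtype=torch.float32
)

# Dispatches to _set_gradient_checkpointing, which now matches the renamed
# LTXVideoEncoder3d / LTXVideoDecoder3d classes.
vae.enable_gradient_checkpointing()

# Decode in tiles to bound memory; with no arguments the tile_sample_*
# defaults set in __init__ (e.g. tile_sample_stride_width = 448) are used.
vae.enable_tiling()
```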

src/diffusers/models/transformers/transformer_ltx.py

@@ -35,7 +35,7 @@ from ..normalization import AdaLayerNormSingle, RMSNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-class LTXAttentionProcessor2_0:
+class LTXVideoAttentionProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
used in the LTX model. It applies a normalization layer and rotary embedding on the query and key vector.
@@ -44,7 +44,7 @@ class LTXAttentionProcessor2_0:
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"LTXAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
"LTXVideoAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def __call__(
@@ -92,7 +92,7 @@ class LTXAttentionProcessor2_0:
return hidden_states
-class LTXRotaryPosEmbed(nn.Module):
+class LTXVideoRotaryPosEmbed(nn.Module):
def __init__(
self,
dim: int,
@@ -164,7 +164,7 @@ class LTXRotaryPosEmbed(nn.Module):
@maybe_allow_in_graph
-class LTXTransformerBlock(nn.Module):
+class LTXVideoTransformerBlock(nn.Module):
r"""
Transformer block used in [LTX](https://huggingface.co/Lightricks/LTX-Video).
@@ -208,7 +208,7 @@ class LTXTransformerBlock(nn.Module):
cross_attention_dim=None,
out_bias=attention_out_bias,
qk_norm=qk_norm,
-processor=LTXAttentionProcessor2_0(),
+processor=LTXVideoAttentionProcessor2_0(),
)
self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
@@ -221,7 +221,7 @@ class LTXTransformerBlock(nn.Module):
bias=attention_bias,
out_bias=attention_out_bias,
qk_norm=qk_norm,
-processor=LTXAttentionProcessor2_0(),
+processor=LTXVideoAttentionProcessor2_0(),
)
self.ff = FeedForward(dim, activation_fn=activation_fn)
@@ -327,7 +327,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
-self.rope = LTXRotaryPosEmbed(
+self.rope = LTXVideoRotaryPosEmbed(
dim=inner_dim,
base_num_frames=20,
base_height=2048,
@@ -339,7 +339,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
self.transformer_blocks = nn.ModuleList(
[
-LTXTransformerBlock(
+LTXVideoTransformerBlock(
dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,