mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Rename LTX blocks and docs title (#10213)
* rename blocks and docs * fix docs --------- Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
This commit is contained in:
@@ -429,7 +429,7 @@
|
||||
- local: api/pipelines/ledits_pp
|
||||
title: LEDITS++
|
||||
- local: api/pipelines/ltx_video
|
||||
title: LTX
|
||||
title: LTXVideo
|
||||
- local: api/pipelines/lumina
|
||||
title: Lumina-T2X
|
||||
- local: api/pipelines/marigold
|
||||
|
||||
@@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
|
||||
```python
|
||||
from diffusers import AutoencoderKLLTXVideo
|
||||
|
||||
vae = AutoencoderKLLTXVideo.from_pretrained("TODO/TODO", subfolder="vae", torch_dtype=torch.float32).to("cuda")
|
||||
vae = AutoencoderKLLTXVideo.from_pretrained("Lightricks/LTX-Video", subfolder="vae", torch_dtype=torch.float32).to("cuda")
|
||||
```
|
||||
|
||||
## AutoencoderKLLTXVideo
|
||||
|
||||
@@ -18,7 +18,7 @@ The model can be loaded with the following code snippet.
|
||||
```python
|
||||
from diffusers import LTXVideoTransformer3DModel
|
||||
|
||||
transformer = LTXVideoTransformer3DModel.from_pretrained("TODO/TODO", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
|
||||
transformer = LTXVideoTransformer3DModel.from_pretrained("Lightricks/LTX-Video", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
|
||||
```
|
||||
|
||||
## LTXVideoTransformer3DModel
|
||||
|
||||
@@ -28,7 +28,7 @@ from ..normalization import RMSNorm
|
||||
from .vae import DecoderOutput, DiagonalGaussianDistribution
|
||||
|
||||
|
||||
class LTXCausalConv3d(nn.Module):
|
||||
class LTXVideoCausalConv3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
@@ -79,9 +79,9 @@ class LTXCausalConv3d(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTXResnetBlock3d(nn.Module):
|
||||
class LTXVideoResnetBlock3d(nn.Module):
|
||||
r"""
|
||||
A 3D ResNet block used in the LTX model.
|
||||
A 3D ResNet block used in the LTXVideo model.
|
||||
|
||||
Args:
|
||||
in_channels (`int`):
|
||||
@@ -117,13 +117,13 @@ class LTXResnetBlock3d(nn.Module):
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
|
||||
self.norm1 = RMSNorm(in_channels, eps=1e-8, elementwise_affine=elementwise_affine)
|
||||
self.conv1 = LTXCausalConv3d(
|
||||
self.conv1 = LTXVideoCausalConv3d(
|
||||
in_channels=in_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal
|
||||
)
|
||||
|
||||
self.norm2 = RMSNorm(out_channels, eps=1e-8, elementwise_affine=elementwise_affine)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.conv2 = LTXCausalConv3d(
|
||||
self.conv2 = LTXVideoCausalConv3d(
|
||||
in_channels=out_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal
|
||||
)
|
||||
|
||||
@@ -131,7 +131,7 @@ class LTXResnetBlock3d(nn.Module):
|
||||
self.conv_shortcut = None
|
||||
if in_channels != out_channels:
|
||||
self.norm3 = nn.LayerNorm(in_channels, eps=eps, elementwise_affine=True, bias=True)
|
||||
self.conv_shortcut = LTXCausalConv3d(
|
||||
self.conv_shortcut = LTXVideoCausalConv3d(
|
||||
in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, is_causal=is_causal
|
||||
)
|
||||
|
||||
@@ -157,7 +157,7 @@ class LTXResnetBlock3d(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTXUpsampler3d(nn.Module):
|
||||
class LTXVideoUpsampler3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
@@ -170,7 +170,7 @@ class LTXUpsampler3d(nn.Module):
|
||||
|
||||
out_channels = in_channels * stride[0] * stride[1] * stride[2]
|
||||
|
||||
self.conv = LTXCausalConv3d(
|
||||
self.conv = LTXVideoCausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
@@ -191,9 +191,9 @@ class LTXUpsampler3d(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTXDownBlock3D(nn.Module):
|
||||
class LTXVideoDownBlock3D(nn.Module):
|
||||
r"""
|
||||
Down block used in the LTX model.
|
||||
Down block used in the LTXVideo model.
|
||||
|
||||
Args:
|
||||
in_channels (`int`):
|
||||
@@ -235,7 +235,7 @@ class LTXDownBlock3D(nn.Module):
|
||||
resnets = []
|
||||
for _ in range(num_layers):
|
||||
resnets.append(
|
||||
LTXResnetBlock3d(
|
||||
LTXVideoResnetBlock3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
dropout=dropout,
|
||||
@@ -250,7 +250,7 @@ class LTXDownBlock3D(nn.Module):
|
||||
if spatio_temporal_scale:
|
||||
self.downsamplers = nn.ModuleList(
|
||||
[
|
||||
LTXCausalConv3d(
|
||||
LTXVideoCausalConv3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
kernel_size=3,
|
||||
@@ -262,7 +262,7 @@ class LTXDownBlock3D(nn.Module):
|
||||
|
||||
self.conv_out = None
|
||||
if in_channels != out_channels:
|
||||
self.conv_out = LTXResnetBlock3d(
|
||||
self.conv_out = LTXVideoResnetBlock3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
dropout=dropout,
|
||||
@@ -300,9 +300,9 @@ class LTXDownBlock3D(nn.Module):
|
||||
|
||||
|
||||
# Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d
|
||||
class LTXMidBlock3d(nn.Module):
|
||||
class LTXVideoMidBlock3d(nn.Module):
|
||||
r"""
|
||||
A middle block used in the LTX model.
|
||||
A middle block used in the LTXVideo model.
|
||||
|
||||
Args:
|
||||
in_channels (`int`):
|
||||
@@ -335,7 +335,7 @@ class LTXMidBlock3d(nn.Module):
|
||||
resnets = []
|
||||
for _ in range(num_layers):
|
||||
resnets.append(
|
||||
LTXResnetBlock3d(
|
||||
LTXVideoResnetBlock3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
dropout=dropout,
|
||||
@@ -367,9 +367,9 @@ class LTXMidBlock3d(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTXUpBlock3d(nn.Module):
|
||||
class LTXVideoUpBlock3d(nn.Module):
|
||||
r"""
|
||||
Up block used in the LTX model.
|
||||
Up block used in the LTXVideo model.
|
||||
|
||||
Args:
|
||||
in_channels (`int`):
|
||||
@@ -410,7 +410,7 @@ class LTXUpBlock3d(nn.Module):
|
||||
|
||||
self.conv_in = None
|
||||
if in_channels != out_channels:
|
||||
self.conv_in = LTXResnetBlock3d(
|
||||
self.conv_in = LTXVideoResnetBlock3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
dropout=dropout,
|
||||
@@ -421,12 +421,12 @@ class LTXUpBlock3d(nn.Module):
|
||||
|
||||
self.upsamplers = None
|
||||
if spatio_temporal_scale:
|
||||
self.upsamplers = nn.ModuleList([LTXUpsampler3d(out_channels, stride=(2, 2, 2), is_causal=is_causal)])
|
||||
self.upsamplers = nn.ModuleList([LTXVideoUpsampler3d(out_channels, stride=(2, 2, 2), is_causal=is_causal)])
|
||||
|
||||
resnets = []
|
||||
for _ in range(num_layers):
|
||||
resnets.append(
|
||||
LTXResnetBlock3d(
|
||||
LTXVideoResnetBlock3d(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
dropout=dropout,
|
||||
@@ -463,9 +463,9 @@ class LTXUpBlock3d(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTXEncoder3d(nn.Module):
|
||||
class LTXVideoEncoder3d(nn.Module):
|
||||
r"""
|
||||
The `LTXEncoder3D` layer of a variational autoencoder that encodes input video samples to its latent
|
||||
The `LTXVideoEncoder3d` layer of a variational autoencoder that encodes input video samples to its latent
|
||||
representation.
|
||||
|
||||
Args:
|
||||
@@ -509,7 +509,7 @@ class LTXEncoder3d(nn.Module):
|
||||
|
||||
output_channel = block_out_channels[0]
|
||||
|
||||
self.conv_in = LTXCausalConv3d(
|
||||
self.conv_in = LTXVideoCausalConv3d(
|
||||
in_channels=self.in_channels,
|
||||
out_channels=output_channel,
|
||||
kernel_size=3,
|
||||
@@ -524,7 +524,7 @@ class LTXEncoder3d(nn.Module):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i]
|
||||
|
||||
down_block = LTXDownBlock3D(
|
||||
down_block = LTXVideoDownBlock3D(
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
num_layers=layers_per_block[i],
|
||||
@@ -536,7 +536,7 @@ class LTXEncoder3d(nn.Module):
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
# mid block
|
||||
self.mid_block = LTXMidBlock3d(
|
||||
self.mid_block = LTXVideoMidBlock3d(
|
||||
in_channels=output_channel,
|
||||
num_layers=layers_per_block[-1],
|
||||
resnet_eps=resnet_norm_eps,
|
||||
@@ -546,14 +546,14 @@ class LTXEncoder3d(nn.Module):
|
||||
# out
|
||||
self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = LTXCausalConv3d(
|
||||
self.conv_out = LTXVideoCausalConv3d(
|
||||
in_channels=output_channel, out_channels=out_channels + 1, kernel_size=3, stride=1, is_causal=is_causal
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
r"""The forward method of the `LTXEncoder3D` class."""
|
||||
r"""The forward method of the `LTXVideoEncoder3d` class."""
|
||||
|
||||
p = self.patch_size
|
||||
p_t = self.patch_size_t
|
||||
@@ -599,9 +599,10 @@ class LTXEncoder3d(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTXDecoder3d(nn.Module):
|
||||
class LTXVideoDecoder3d(nn.Module):
|
||||
r"""
|
||||
The `LTXDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output sample.
|
||||
The `LTXVideoDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output
|
||||
sample.
|
||||
|
||||
Args:
|
||||
in_channels (`int`, defaults to 128):
|
||||
@@ -647,11 +648,11 @@ class LTXDecoder3d(nn.Module):
|
||||
layers_per_block = tuple(reversed(layers_per_block))
|
||||
output_channel = block_out_channels[0]
|
||||
|
||||
self.conv_in = LTXCausalConv3d(
|
||||
self.conv_in = LTXVideoCausalConv3d(
|
||||
in_channels=in_channels, out_channels=output_channel, kernel_size=3, stride=1, is_causal=is_causal
|
||||
)
|
||||
|
||||
self.mid_block = LTXMidBlock3d(
|
||||
self.mid_block = LTXVideoMidBlock3d(
|
||||
in_channels=output_channel, num_layers=layers_per_block[0], resnet_eps=resnet_norm_eps, is_causal=is_causal
|
||||
)
|
||||
|
||||
@@ -662,7 +663,7 @@ class LTXDecoder3d(nn.Module):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
|
||||
up_block = LTXUpBlock3d(
|
||||
up_block = LTXVideoUpBlock3d(
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
num_layers=layers_per_block[i + 1],
|
||||
@@ -676,7 +677,7 @@ class LTXDecoder3d(nn.Module):
|
||||
# out
|
||||
self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = LTXCausalConv3d(
|
||||
self.conv_out = LTXVideoCausalConv3d(
|
||||
in_channels=output_channel, out_channels=self.out_channels, kernel_size=3, stride=1, is_causal=is_causal
|
||||
)
|
||||
|
||||
@@ -777,7 +778,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.encoder = LTXEncoder3d(
|
||||
self.encoder = LTXVideoEncoder3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=latent_channels,
|
||||
block_out_channels=block_out_channels,
|
||||
@@ -788,7 +789,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
resnet_norm_eps=resnet_norm_eps,
|
||||
is_causal=encoder_causal,
|
||||
)
|
||||
self.decoder = LTXDecoder3d(
|
||||
self.decoder = LTXVideoDecoder3d(
|
||||
in_channels=latent_channels,
|
||||
out_channels=out_channels,
|
||||
block_out_channels=block_out_channels,
|
||||
@@ -837,7 +838,7 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
self.tile_sample_stride_width = 448
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, (LTXEncoder3d, LTXDecoder3d)):
|
||||
if isinstance(module, (LTXVideoEncoder3d, LTXVideoDecoder3d)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def enable_tiling(
|
||||
|
||||
@@ -35,7 +35,7 @@ from ..normalization import AdaLayerNormSingle, RMSNorm
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class LTXAttentionProcessor2_0:
|
||||
class LTXVideoAttentionProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
|
||||
used in the LTX model. It applies a normalization layer and rotary embedding on the query and key vector.
|
||||
@@ -44,7 +44,7 @@ class LTXAttentionProcessor2_0:
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
"LTXAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
||||
"LTXVideoAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
||||
)
|
||||
|
||||
def __call__(
|
||||
@@ -92,7 +92,7 @@ class LTXAttentionProcessor2_0:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTXRotaryPosEmbed(nn.Module):
|
||||
class LTXVideoRotaryPosEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
@@ -164,7 +164,7 @@ class LTXRotaryPosEmbed(nn.Module):
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class LTXTransformerBlock(nn.Module):
|
||||
class LTXVideoTransformerBlock(nn.Module):
|
||||
r"""
|
||||
Transformer block used in [LTX](https://huggingface.co/Lightricks/LTX-Video).
|
||||
|
||||
@@ -208,7 +208,7 @@ class LTXTransformerBlock(nn.Module):
|
||||
cross_attention_dim=None,
|
||||
out_bias=attention_out_bias,
|
||||
qk_norm=qk_norm,
|
||||
processor=LTXAttentionProcessor2_0(),
|
||||
processor=LTXVideoAttentionProcessor2_0(),
|
||||
)
|
||||
|
||||
self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
|
||||
@@ -221,7 +221,7 @@ class LTXTransformerBlock(nn.Module):
|
||||
bias=attention_bias,
|
||||
out_bias=attention_out_bias,
|
||||
qk_norm=qk_norm,
|
||||
processor=LTXAttentionProcessor2_0(),
|
||||
processor=LTXVideoAttentionProcessor2_0(),
|
||||
)
|
||||
|
||||
self.ff = FeedForward(dim, activation_fn=activation_fn)
|
||||
@@ -327,7 +327,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
|
||||
|
||||
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
|
||||
|
||||
self.rope = LTXRotaryPosEmbed(
|
||||
self.rope = LTXVideoRotaryPosEmbed(
|
||||
dim=inner_dim,
|
||||
base_num_frames=20,
|
||||
base_height=2048,
|
||||
@@ -339,7 +339,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
LTXTransformerBlock(
|
||||
LTXVideoTransformerBlock(
|
||||
dim=inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
|
||||
Reference in New Issue
Block a user