
Get diffusers implementation on par with official LTX 2.0 video VAE implementation

Daniel Gu
2025-12-19 07:02:38 +01:00
parent 491aae08d8
commit a748975a7c
3 changed files with 174 additions and 115 deletions

View File

@@ -240,7 +240,9 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A
"patch_size_t": 1,
"resnet_norm_eps": 1e-6,
"encoder_causal": True,
"decoder_causal": True,
"decoder_causal": False,
"encoder_spatial_padding_mode": "zeros",
"decoder_spatial_padding_mode": "reflect",
"spatial_compression_ratio": 32,
"temporal_compression_ratio": 8,
},
@@ -275,7 +277,9 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A
"patch_size_t": 1,
"resnet_norm_eps": 1e-6,
"encoder_causal": True,
"decoder_causal": True,
"decoder_causal": False,
"encoder_spatial_padding_mode": "zeros",
"decoder_spatial_padding_mode": "reflect",
"spatial_compression_ratio": 32,
"temporal_compression_ratio": 8,
},
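
These config entries map one-to-one onto the model's `__init__` arguments (see the autoencoder diff below). A minimal construction sketch, with the import path assumed and all other arguments left at their defaults:

from diffusers import AutoencoderKLLTX2Video  # import path assumed

# Sketch only: mirrors the updated config entries above.
vae = AutoencoderKLLTX2Video(
    encoder_causal=True,                       # encoder keeps causal temporal padding
    decoder_causal=False,                      # decoder now defaults to non-causal
    encoder_spatial_padding_mode="zeros",
    decoder_spatial_padding_mode="reflect",
    spatial_compression_ratio=32,
    temporal_compression_ratio=8,
)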

View File

@@ -29,8 +29,8 @@ from ..normalization import RMSNorm
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
# Copied from diffusers.models.autoencoders.autoencoder_kl_ltx.LTXVideoCausalConv3d
class LTXVideoCausalConv3d(nn.Module):
# Like LTXCausalConv3d, but whether causal inference is performed can be specified at runtime
class LTX2VideoCausalConv3d(nn.Module):
def __init__(
self,
in_channels: int,
@@ -39,14 +39,12 @@ class LTXVideoCausalConv3d(nn.Module):
stride: Union[int, Tuple[int, int, int]] = 1,
dilation: Union[int, Tuple[int, int, int]] = 1,
groups: int = 1,
padding_mode: str = "zeros",
is_causal: bool = True,
spatial_padding_mode: str = "zeros",
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.is_causal = is_causal
self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size, kernel_size)
dilation = dilation if isinstance(dilation, tuple) else (dilation, 1, 1)
@@ -63,13 +61,13 @@ class LTXVideoCausalConv3d(nn.Module):
dilation=dilation,
groups=groups,
padding=padding,
padding_mode=padding_mode,
padding_mode=spatial_padding_mode,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Tensor:
time_kernel_size = self.kernel_size[0]
if self.is_causal:
if causal:
pad_left = hidden_states[:, :, :1, :, :].repeat((1, 1, time_kernel_size - 1, 1, 1))
hidden_states = torch.concatenate([pad_left, hidden_states], dim=2)
else:
@@ -81,7 +79,8 @@ class LTXVideoCausalConv3d(nn.Module):
return hidden_states
# Like LTXVideoResnetBlock3d, but uses a normal Conv3d instead of a causal Conv3d for the conv_shortcut
# Like LTXVideoResnetBlock3d, but uses new causal Conv3d, normal Conv3d for the conv_shortcut, and the spatial padding
# mode is configurable
class LTX2VideoResnetBlock3d(nn.Module):
r"""
A 3D ResNet block used in the LTX 2.0 audiovisual model.
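
Aside: the new runtime `causal` argument on `LTX2VideoCausalConv3d.forward` only changes the temporal padding; spatial padding always follows the configured `spatial_padding_mode`. A standalone sketch of the two temporal branches (the non-causal branch is elided from the hunk above, so the symmetric split shown here is an assumption):

import torch

def temporal_pad(x: torch.Tensor, time_kernel_size: int, causal: bool) -> torch.Tensor:
    # x: (batch, channels, frames, height, width)
    if causal:
        # Repeat the first frame on the left so no output position sees future frames.
        pad_left = x[:, :, :1].repeat(1, 1, time_kernel_size - 1, 1, 1)
        return torch.cat([pad_left, x], dim=2)
    # Assumed non-causal behavior: split the same replicate padding across both ends.
    left = (time_kernel_size - 1) // 2
    right = time_kernel_size - 1 - left
    pad_left = x[:, :, :1].repeat(1, 1, left, 1, 1)
    pad_right = x[:, :, -1:].repeat(1, 1, right, 1, 1)
    return torch.cat([pad_left, x, pad_right], dim=2)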
@@ -111,9 +110,9 @@ class LTX2VideoResnetBlock3d(nn.Module):
eps: float = 1e-6,
elementwise_affine: bool = False,
non_linearity: str = "swish",
is_causal: bool = True,
inject_noise: bool = False,
timestep_conditioning: bool = False,
spatial_padding_mode: str = "zeros",
) -> None:
super().__init__()
@@ -122,14 +121,20 @@ class LTX2VideoResnetBlock3d(nn.Module):
self.nonlinearity = get_activation(non_linearity)
self.norm1 = RMSNorm(in_channels, eps=1e-8, elementwise_affine=elementwise_affine)
self.conv1 = LTXVideoCausalConv3d(
in_channels=in_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal
self.conv1 = LTX2VideoCausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
spatial_padding_mode=spatial_padding_mode,
)
self.norm2 = RMSNorm(out_channels, eps=1e-8, elementwise_affine=elementwise_affine)
self.dropout = nn.Dropout(dropout)
self.conv2 = LTXVideoCausalConv3d(
in_channels=out_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal
self.conv2 = LTX2VideoCausalConv3d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
spatial_padding_mode=spatial_padding_mode,
)
self.norm3 = None
@@ -140,9 +145,6 @@ class LTX2VideoResnetBlock3d(nn.Module):
self.conv_shortcut = nn.Conv3d(
in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1
)
# self.conv_shortcut = LTXVideoCausalConv3d(
# in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, is_causal=is_causal
# )
self.per_channel_scale1 = None
self.per_channel_scale2 = None
@@ -155,7 +157,11 @@ class LTX2VideoResnetBlock3d(nn.Module):
self.scale_shift_table = nn.Parameter(torch.randn(4, in_channels) / in_channels**0.5)
def forward(
self, inputs: torch.Tensor, temb: Optional[torch.Tensor] = None, generator: Optional[torch.Generator] = None
self,
inputs: torch.Tensor,
temb: Optional[torch.Tensor] = None,
generator: Optional[torch.Generator] = None,
causal: bool = True,
) -> torch.Tensor:
hidden_states = inputs
@@ -168,7 +174,7 @@ class LTX2VideoResnetBlock3d(nn.Module):
hidden_states = hidden_states * (1 + scale_1) + shift_1
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
hidden_states = self.conv1(hidden_states, causal=causal)
if self.per_channel_scale1 is not None:
spatial_shape = hidden_states.shape[-2:]
@@ -184,7 +190,7 @@ class LTX2VideoResnetBlock3d(nn.Module):
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
hidden_states = self.conv2(hidden_states, causal=causal)
if self.per_channel_scale2 is not None:
spatial_shape = hidden_states.shape[-2:]
@@ -203,15 +209,14 @@ class LTX2VideoResnetBlock3d(nn.Module):
return hidden_states
# Copied from diffusers.models.autoencoders.autoencoder_kl_ltx.LTXVideoDownsampler3d
# Like LTX 1.0 LTXVideoDownsampler3d, but uses new causal Conv3d
class LTXVideoDownsampler3d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
stride: Union[int, Tuple[int, int, int]] = 1,
is_causal: bool = True,
padding_mode: str = "zeros",
spatial_padding_mode: str = "zeros",
) -> None:
super().__init__()
@@ -220,16 +225,15 @@ class LTXVideoDownsampler3d(nn.Module):
out_channels = out_channels // (self.stride[0] * self.stride[1] * self.stride[2])
self.conv = LTXVideoCausalConv3d(
self.conv = LTX2VideoCausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=1,
is_causal=is_causal,
padding_mode=padding_mode,
spatial_padding_mode=spatial_padding_mode,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Tensor:
hidden_states = torch.cat([hidden_states[:, :, : self.stride[0] - 1], hidden_states], dim=2)
residual = (
@@ -241,7 +245,7 @@ class LTXVideoDownsampler3d(nn.Module):
residual = residual.unflatten(1, (-1, self.group_size))
residual = residual.mean(dim=2)
hidden_states = self.conv(hidden_states)
hidden_states = self.conv(hidden_states, causal=causal)
hidden_states = (
hidden_states.unflatten(4, (-1, self.stride[2]))
.unflatten(3, (-1, self.stride[1]))
@@ -253,16 +257,15 @@ class LTXVideoDownsampler3d(nn.Module):
return hidden_states
# Copied from diffusers.models.autoencoders.autoencoder_kl_ltx.LTXVideoUpsampler3d
# Like LTX 1.0 LTXVideoUpsampler3d, but uses new causal Conv3d
class LTXVideoUpsampler3d(nn.Module):
def __init__(
self,
in_channels: int,
stride: Union[int, Tuple[int, int, int]] = 1,
is_causal: bool = True,
residual: bool = False,
upscale_factor: int = 1,
padding_mode: str = "zeros",
spatial_padding_mode: str = "zeros",
) -> None:
super().__init__()
@@ -272,16 +275,15 @@ class LTXVideoUpsampler3d(nn.Module):
out_channels = (in_channels * stride[0] * stride[1] * stride[2]) // upscale_factor
self.conv = LTXVideoCausalConv3d(
self.conv = LTX2VideoCausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=1,
is_causal=is_causal,
padding_mode=padding_mode,
spatial_padding_mode=spatial_padding_mode,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
if self.residual:
@@ -293,7 +295,7 @@ class LTXVideoUpsampler3d(nn.Module):
residual = residual.repeat(1, repeats, 1, 1, 1)
residual = residual[:, :, self.stride[0] - 1 :]
hidden_states = self.conv(hidden_states)
hidden_states = self.conv(hidden_states, causal=causal)
hidden_states = hidden_states.reshape(
batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width
)
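
Both samplers hinge on the same stride-to-channel rearrangement: the downsampler folds each (stride_t, stride_h, stride_w) neighborhood into channels, and the upsampler unfolds channels back into frames and pixels. A shape-only sketch of the fold direction (layout assumed; not the exact diffusers code):

import torch

def fold_strides_into_channels(x: torch.Tensor, st: int, sh: int, sw: int) -> torch.Tensor:
    # x: (b, c, f, h, w), with f, h, w divisible by st, sh, sw respectively.
    b, c, f, h, w = x.shape
    x = x.reshape(b, c, f // st, st, h // sh, sh, w // sw, sw)
    x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)
    return x  # (b, c * st * sh * sw, f // st, h // sh, w // sw)

x = torch.randn(1, 4, 8, 32, 32)
print(fold_strides_into_channels(x, 2, 2, 2).shape)  # torch.Size([1, 32, 4, 16, 16])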
@@ -342,8 +344,8 @@ class LTX2VideoDownBlock3D(nn.Module):
resnet_eps: float = 1e-6,
resnet_act_fn: str = "swish",
spatio_temporal_scale: bool = True,
is_causal: bool = True,
downsample_type: str = "conv",
spatial_padding_mode: str = "zeros",
):
super().__init__()
@@ -358,7 +360,7 @@ class LTX2VideoDownBlock3D(nn.Module):
dropout=dropout,
eps=resnet_eps,
non_linearity=resnet_act_fn,
is_causal=is_causal,
spatial_padding_mode=spatial_padding_mode,
)
)
self.resnets = nn.ModuleList(resnets)
@@ -369,30 +371,39 @@ class LTX2VideoDownBlock3D(nn.Module):
if downsample_type == "conv":
self.downsamplers.append(
LTXVideoCausalConv3d(
LTX2VideoCausalConv3d(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=3,
stride=(2, 2, 2),
is_causal=is_causal,
spatial_padding_mode=spatial_padding_mode,
)
)
elif downsample_type == "spatial":
self.downsamplers.append(
LTXVideoDownsampler3d(
in_channels=in_channels, out_channels=out_channels, stride=(1, 2, 2), is_causal=is_causal
in_channels=in_channels,
out_channels=out_channels,
stride=(1, 2, 2),
spatial_padding_mode=spatial_padding_mode,
)
)
elif downsample_type == "temporal":
self.downsamplers.append(
LTXVideoDownsampler3d(
in_channels=in_channels, out_channels=out_channels, stride=(2, 1, 1), is_causal=is_causal
in_channels=in_channels,
out_channels=out_channels,
stride=(2, 1, 1),
spatial_padding_mode=spatial_padding_mode,
)
)
elif downsample_type == "spatiotemporal":
self.downsamplers.append(
LTXVideoDownsampler3d(
in_channels=in_channels, out_channels=out_channels, stride=(2, 2, 2), is_causal=is_causal
in_channels=in_channels,
out_channels=out_channels,
stride=(2, 2, 2),
spatial_padding_mode=spatial_padding_mode,
)
)
@@ -403,18 +414,19 @@ class LTX2VideoDownBlock3D(nn.Module):
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
generator: Optional[torch.Generator] = None,
causal: bool = True,
) -> torch.Tensor:
r"""Forward method of the `LTXDownBlock3D` class."""
for i, resnet in enumerate(self.resnets):
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator, causal)
else:
hidden_states = resnet(hidden_states, temb, generator)
hidden_states = resnet(hidden_states, temb, generator, causal=causal)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
hidden_states = downsampler(hidden_states, causal=causal)
return hidden_states
@@ -449,9 +461,9 @@ class LTX2VideoMidBlock3d(nn.Module):
dropout: float = 0.0,
resnet_eps: float = 1e-6,
resnet_act_fn: str = "swish",
is_causal: bool = True,
inject_noise: bool = False,
timestep_conditioning: bool = False,
spatial_padding_mode: str = "zeros",
) -> None:
super().__init__()
@@ -468,9 +480,9 @@ class LTX2VideoMidBlock3d(nn.Module):
dropout=dropout,
eps=resnet_eps,
non_linearity=resnet_act_fn,
is_causal=is_causal,
inject_noise=inject_noise,
timestep_conditioning=timestep_conditioning,
spatial_padding_mode=spatial_padding_mode,
)
)
self.resnets = nn.ModuleList(resnets)
@@ -482,6 +494,7 @@ class LTX2VideoMidBlock3d(nn.Module):
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
generator: Optional[torch.Generator] = None,
causal: bool = True,
) -> torch.Tensor:
r"""Forward method of the `LTXMidBlock3D` class."""
@@ -497,9 +510,9 @@ class LTX2VideoMidBlock3d(nn.Module):
for i, resnet in enumerate(self.resnets):
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator, causal)
else:
hidden_states = resnet(hidden_states, temb, generator)
hidden_states = resnet(hidden_states, temb, generator, causal=causal)
return hidden_states
@@ -540,11 +553,11 @@ class LTX2VideoUpBlock3d(nn.Module):
resnet_eps: float = 1e-6,
resnet_act_fn: str = "swish",
spatio_temporal_scale: bool = True,
is_causal: bool = True,
inject_noise: bool = False,
timestep_conditioning: bool = False,
upsample_residual: bool = False,
upscale_factor: int = 1,
spatial_padding_mode: str = "zeros",
):
super().__init__()
@@ -562,9 +575,9 @@ class LTX2VideoUpBlock3d(nn.Module):
dropout=dropout,
eps=resnet_eps,
non_linearity=resnet_act_fn,
is_causal=is_causal,
inject_noise=inject_noise,
timestep_conditioning=timestep_conditioning,
spatial_padding_mode=spatial_padding_mode,
)
self.upsamplers = None
@@ -574,9 +587,9 @@ class LTX2VideoUpBlock3d(nn.Module):
LTXVideoUpsampler3d(
out_channels * upscale_factor,
stride=(2, 2, 2),
is_causal=is_causal,
residual=upsample_residual,
upscale_factor=upscale_factor,
spatial_padding_mode=spatial_padding_mode,
)
]
)
@@ -590,9 +603,9 @@ class LTX2VideoUpBlock3d(nn.Module):
dropout=dropout,
eps=resnet_eps,
non_linearity=resnet_act_fn,
is_causal=is_causal,
inject_noise=inject_noise,
timestep_conditioning=timestep_conditioning,
spatial_padding_mode=spatial_padding_mode,
)
)
self.resnets = nn.ModuleList(resnets)
@@ -604,9 +617,10 @@ class LTX2VideoUpBlock3d(nn.Module):
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
generator: Optional[torch.Generator] = None,
causal: bool = True,
) -> torch.Tensor:
if self.conv_in is not None:
hidden_states = self.conv_in(hidden_states, temb, generator)
hidden_states = self.conv_in(hidden_states, temb, generator, causal=causal)
if self.time_embedder is not None:
temb = self.time_embedder(
@@ -620,13 +634,13 @@ class LTX2VideoUpBlock3d(nn.Module):
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
hidden_states = upsampler(hidden_states, causal=causal)
for i, resnet in enumerate(self.resnets):
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator)
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator, causal)
else:
hidden_states = resnet(hidden_states, temb, generator)
hidden_states = resnet(hidden_states, temb, generator, causal=causal)
return hidden_states
@@ -682,21 +696,23 @@ class LTX2VideoEncoder3d(nn.Module):
patch_size_t: int = 1,
resnet_norm_eps: float = 1e-6,
is_causal: bool = True,
spatial_padding_mode: str = "zeros",
):
super().__init__()
self.patch_size = patch_size
self.patch_size_t = patch_size_t
self.in_channels = in_channels * patch_size**2
self.is_causal = is_causal
output_channel = out_channels
self.conv_in = LTXVideoCausalConv3d(
self.conv_in = LTX2VideoCausalConv3d(
in_channels=self.in_channels,
out_channels=output_channel,
kernel_size=3,
stride=1,
is_causal=is_causal,
spatial_padding_mode=spatial_padding_mode,
)
# down blocks
@@ -713,8 +729,8 @@ class LTX2VideoEncoder3d(nn.Module):
num_layers=layers_per_block[i],
resnet_eps=resnet_norm_eps,
spatio_temporal_scale=spatio_temporal_scaling[i],
is_causal=is_causal,
downsample_type=downsample_type[i],
spatial_padding_mode=spatial_padding_mode,
)
else:
raise ValueError(f"Unknown down block type: {down_block_types[i]}")
@@ -726,19 +742,23 @@ class LTX2VideoEncoder3d(nn.Module):
in_channels=output_channel,
num_layers=layers_per_block[-1],
resnet_eps=resnet_norm_eps,
is_causal=is_causal,
spatial_padding_mode=spatial_padding_mode,
)
# out
self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False)
self.conv_act = nn.SiLU()
self.conv_out = LTXVideoCausalConv3d(
in_channels=output_channel, out_channels=out_channels + 1, kernel_size=3, stride=1, is_causal=is_causal
self.conv_out = LTX2VideoCausalConv3d(
in_channels=output_channel,
out_channels=out_channels + 1,
kernel_size=3,
stride=1,
spatial_padding_mode=spatial_padding_mode,
)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
def forward(self, hidden_states: torch.Tensor, causal: Optional[bool] = None) -> torch.Tensor:
r"""The forward method of the `LTXVideoEncoder3d` class."""
p = self.patch_size
@@ -748,28 +768,29 @@ class LTX2VideoEncoder3d(nn.Module):
post_patch_num_frames = num_frames // p_t
post_patch_height = height // p
post_patch_width = width // p
causal = causal or self.is_causal
hidden_states = hidden_states.reshape(
batch_size, num_channels, post_patch_num_frames, p_t, post_patch_height, p, post_patch_width, p
)
# Thanks for driving me insane with the weird patching order :(
hidden_states = hidden_states.permute(0, 1, 3, 7, 5, 2, 4, 6).flatten(1, 4)
hidden_states = self.conv_in(hidden_states)
hidden_states = self.conv_in(hidden_states, causal=causal)
if torch.is_grad_enabled() and self.gradient_checkpointing:
for down_block in self.down_blocks:
hidden_states = self._gradient_checkpointing_func(down_block, hidden_states)
hidden_states = self._gradient_checkpointing_func(down_block, hidden_states, None, None, causal)
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, None, None, causal)
else:
for down_block in self.down_blocks:
hidden_states = down_block(hidden_states)
hidden_states = down_block(hidden_states, causal=causal)
hidden_states = self.mid_block(hidden_states)
hidden_states = self.mid_block(hidden_states, causal=causal)
hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
hidden_states = self.conv_out(hidden_states, causal=causal)
last_channel = hidden_states[:, -1:]
last_channel = last_channel.repeat(1, hidden_states.size(1) - 2, 1, 1, 1)
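
The permute order in this hunk (temporal patch, then width patch, then height patch folded into channels) is what the inline comment is lamenting. A worked shape example with assumed patch sizes:

import torch

b, c, f, h, w = 1, 3, 8, 64, 64
p, p_t = 4, 1  # assumed values, matching the defaults visible elsewhere in this diff
x = torch.randn(b, c, f, h, w)
x = x.reshape(b, c, f // p_t, p_t, h // p, p, w // p, p)
x = x.permute(0, 1, 3, 7, 5, 2, 4, 6).flatten(1, 4)  # note (p_t, p_w, p_h), not (p_t, p_h, p_w)
print(x.shape)  # torch.Size([1, 48, 8, 16, 16])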
@@ -817,17 +838,19 @@ class LTX2VideoDecoder3d(nn.Module):
patch_size: int = 4,
patch_size_t: int = 1,
resnet_norm_eps: float = 1e-6,
is_causal: bool = True,
is_causal: bool = False,
inject_noise: Tuple[bool, ...] = (False, False, False),
timestep_conditioning: bool = False,
upsample_residual: Tuple[bool, ...] = (True, True, True),
upsample_factor: Tuple[bool, ...] = (2, 2, 2),
spatial_padding_mode: str = "reflect",
) -> None:
super().__init__()
self.patch_size = patch_size
self.patch_size_t = patch_size_t
self.out_channels = out_channels * patch_size**2
self.is_causal = is_causal
block_out_channels = tuple(reversed(block_out_channels))
spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling))
@@ -837,17 +860,21 @@ class LTX2VideoDecoder3d(nn.Module):
upsample_factor = tuple(reversed(upsample_factor))
output_channel = block_out_channels[0]
self.conv_in = LTXVideoCausalConv3d(
in_channels=in_channels, out_channels=output_channel, kernel_size=3, stride=1, is_causal=is_causal
self.conv_in = LTX2VideoCausalConv3d(
in_channels=in_channels,
out_channels=output_channel,
kernel_size=3,
stride=1,
spatial_padding_mode=spatial_padding_mode,
)
self.mid_block = LTX2VideoMidBlock3d(
in_channels=output_channel,
num_layers=layers_per_block[0],
resnet_eps=resnet_norm_eps,
is_causal=is_causal,
inject_noise=inject_noise[0],
timestep_conditioning=timestep_conditioning,
spatial_padding_mode=spatial_padding_mode,
)
# up blocks
@@ -863,11 +890,11 @@ class LTX2VideoDecoder3d(nn.Module):
num_layers=layers_per_block[i + 1],
resnet_eps=resnet_norm_eps,
spatio_temporal_scale=spatio_temporal_scaling[i],
is_causal=is_causal,
inject_noise=inject_noise[i + 1],
timestep_conditioning=timestep_conditioning,
upsample_residual=upsample_residual[i],
upscale_factor=upsample_factor[i],
spatial_padding_mode=spatial_padding_mode,
)
self.up_blocks.append(up_block)
@@ -875,8 +902,12 @@ class LTX2VideoDecoder3d(nn.Module):
# out
self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False)
self.conv_act = nn.SiLU()
self.conv_out = LTXVideoCausalConv3d(
in_channels=output_channel, out_channels=self.out_channels, kernel_size=3, stride=1, is_causal=is_causal
self.conv_out = LTX2VideoCausalConv3d(
in_channels=output_channel,
out_channels=self.out_channels,
kernel_size=3,
stride=1,
spatial_padding_mode=spatial_padding_mode,
)
# timestep embedding
@@ -890,22 +921,26 @@ class LTX2VideoDecoder3d(nn.Module):
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
hidden_states = self.conv_in(hidden_states)
def forward(
self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, causal: Optional[bool] = None,
) -> torch.Tensor:
causal = causal or self.is_causal
hidden_states = self.conv_in(hidden_states, causal=causal)
if self.timestep_scale_multiplier is not None:
temb = temb * self.timestep_scale_multiplier
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb)
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb, None, causal)
for up_block in self.up_blocks:
hidden_states = self._gradient_checkpointing_func(up_block, hidden_states, temb)
hidden_states = self._gradient_checkpointing_func(up_block, hidden_states, temb, None, causal)
else:
hidden_states = self.mid_block(hidden_states, temb)
hidden_states = self.mid_block(hidden_states, temb, causal=causal)
for up_block in self.up_blocks:
hidden_states = up_block(hidden_states, temb)
hidden_states = up_block(hidden_states, temb, causal=causal)
hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
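
One subtlety in `causal = causal or self.is_causal` above: an explicit `causal=False` falls back to the configured default whenever `is_causal` is True, so the per-call flag can force causal mode on but not off. A resolution that would treat only `None` as "use the default" (shown for contrast; not what the diff does):

causal = self.is_causal if causal is None else causal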
@@ -923,7 +958,7 @@ class LTX2VideoDecoder3d(nn.Module):
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states)
hidden_states = self.conv_out(hidden_states, causal=causal)
p = self.patch_size
p_t = self.patch_size_t
@@ -1006,6 +1041,8 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
scaling_factor: float = 1.0,
encoder_causal: bool = True,
decoder_causal: bool = True,
encoder_spatial_padding_mode: str = "zeros",
decoder_spatial_padding_mode: str = "reflect",
spatial_compression_ratio: int = None,
temporal_compression_ratio: int = None,
) -> None:
@@ -1023,6 +1060,7 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
patch_size_t=patch_size_t,
resnet_norm_eps=resnet_norm_eps,
is_causal=encoder_causal,
spatial_padding_mode=encoder_spatial_padding_mode,
)
self.decoder = LTX2VideoDecoder3d(
in_channels=latent_channels,
@@ -1038,6 +1076,7 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
inject_noise=decoder_inject_noise,
upsample_residual=upsample_residual,
upsample_factor=upsample_factor,
spatial_padding_mode=decoder_spatial_padding_mode,
)
latents_mean = torch.zeros((latent_channels,), requires_grad=False)
@@ -1120,22 +1159,22 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
def _encode(self, x: torch.Tensor) -> torch.Tensor:
def _encode(self, x: torch.Tensor, causal: Optional[bool] = None) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = x.shape
if self.use_framewise_decoding and num_frames > self.tile_sample_min_num_frames:
return self._temporal_tiled_encode(x)
return self._temporal_tiled_encode(x, causal=causal)
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
return self.tiled_encode(x)
return self.tiled_encode(x, causal=causal)
enc = self.encoder(x)
enc = self.encoder(x, causal=causal)
return enc
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
self, x: torch.Tensor, causal: Optional[bool] = None, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of images into latents.
@@ -1150,10 +1189,10 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
encoded_slices = [self._encode(x_slice, causal=causal) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)
h = self._encode(x, causal=causal)
posterior = DiagonalGaussianDistribution(h)
if not return_dict:
@@ -1161,7 +1200,11 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(
self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True
self,
z: torch.Tensor,
temb: Optional[torch.Tensor] = None,
causal: Optional[bool] = None,
return_dict: bool = True,
) -> Union[DecoderOutput, torch.Tensor]:
batch_size, num_channels, num_frames, height, width = z.shape
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
@@ -1169,12 +1212,12 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames:
return self._temporal_tiled_decode(z, temb, return_dict=return_dict)
return self._temporal_tiled_decode(z, temb, causal=causal, return_dict=return_dict)
if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
return self.tiled_decode(z, temb, return_dict=return_dict)
return self.tiled_decode(z, temb, causal=causal, return_dict=return_dict)
dec = self.decoder(z, temb)
dec = self.decoder(z, temb, causal=causal)
if not return_dict:
return (dec,)
@@ -1183,7 +1226,11 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
@apply_forward_hook
def decode(
self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True
self,
z: torch.Tensor,
temb: Optional[torch.Tensor] = None,
causal: Optional[bool] = None,
return_dict: bool = True,
) -> Union[DecoderOutput, torch.Tensor]:
"""
Decode a batch of images.
@@ -1201,13 +1248,13 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
if self.use_slicing and z.shape[0] > 1:
if temb is not None:
decoded_slices = [
self._decode(z_slice, t_slice).sample for z_slice, t_slice in (z.split(1), temb.split(1))
self._decode(z_slice, t_slice, causal=causal).sample for z_slice, t_slice in (z.split(1), temb.split(1))
]
else:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded_slices = [self._decode(z_slice, causal=causal).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z, temb).sample
decoded = self._decode(z, temb, causal=causal).sample
if not return_dict:
return (decoded,)
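
A hypothetical end-to-end use of the new per-call flags (checkpoint path, input shape, and the absence of timestep conditioning are all assumed):

import torch
from diffusers import AutoencoderKLLTX2Video  # import path assumed

vae = AutoencoderKLLTX2Video.from_pretrained("<repo-id>", subfolder="vae")
video = torch.randn(1, 3, 25, 256, 256)  # (batch, channels, frames, height, width)
latents = vae.encode(video, causal=True).latent_dist.sample()
frames = vae.decode(latents, causal=False).sample  # temb omitted: no timestep conditioning assumed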
@@ -1238,7 +1285,7 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
)
return b
def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
def tiled_encode(self, x: torch.Tensor, causal: Optional[bool] = None) -> torch.Tensor:
r"""Encode a batch of images using a tiled encoder.
Args:
@@ -1267,7 +1314,8 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
row = []
for j in range(0, width, self.tile_sample_stride_width):
time = self.encoder(
x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width],
causal=causal,
)
row.append(time)
@@ -1290,7 +1338,7 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
return enc
def tiled_decode(
self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True
self, z: torch.Tensor, temb: Optional[torch.Tensor], causal: Optional[bool] = None, return_dict: bool = True
) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images using a tiled decoder.
@@ -1324,7 +1372,9 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
for i in range(0, height, tile_latent_stride_height):
row = []
for j in range(0, width, tile_latent_stride_width):
time = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb)
time = self.decoder(
z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb, causal=causal
)
row.append(time)
rows.append(row)
@@ -1349,7 +1399,7 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
return DecoderOutput(sample=dec)
def _temporal_tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
def _temporal_tiled_encode(self, x: torch.Tensor, causal: Optional[bool] = None) -> AutoencoderKLOutput:
batch_size, num_channels, num_frames, height, width = x.shape
latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1
@@ -1361,9 +1411,9 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
for i in range(0, num_frames, self.tile_sample_stride_num_frames):
tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :]
if self.use_tiling and (height > self.tile_sample_min_height or width > self.tile_sample_min_width):
tile = self.tiled_encode(tile)
tile = self.tiled_encode(tile, causal=causal)
else:
tile = self.encoder(tile)
tile = self.encoder(tile, causal=causal)
if i > 0:
tile = tile[:, :, 1:, :, :]
row.append(tile)
@@ -1380,7 +1430,7 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
return enc
def _temporal_tiled_decode(
self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True
self, z: torch.Tensor, temb: Optional[torch.Tensor], causal: Optional[bool] = None, return_dict: bool = True
) -> Union[DecoderOutput, torch.Tensor]:
batch_size, num_channels, num_frames, height, width = z.shape
num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
@@ -1395,9 +1445,9 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
for i in range(0, num_frames, tile_latent_stride_num_frames):
tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :]
if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height):
decoded = self.tiled_decode(tile, temb, return_dict=True).sample
decoded = self.tiled_decode(tile, temb, causal=causal, return_dict=True).sample
else:
decoded = self.decoder(tile, temb)
decoded = self.decoder(tile, temb, causal=causal)
if i > 0:
decoded = decoded[:, :, :-1, :, :]
row.append(decoded)
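
The temporal blend helper itself is elided from this diff; a minimal linear cross-fade over the overlapping frames, consistent with the tile trimming above (a sketch, not the diffusers implementation):

import torch

def blend_t(a: torch.Tensor, b: torch.Tensor, overlap: int) -> torch.Tensor:
    # Cross-fade the trailing `overlap` frames of tile `a` into the leading
    # frames of tile `b`; dim 2 is time.
    w = torch.linspace(0, 1, overlap, device=b.device, dtype=b.dtype).view(1, 1, -1, 1, 1)
    b[:, :, :overlap] = a[:, :, -overlap:] * (1 - w) + b[:, :, :overlap] * w
    return b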
@@ -1422,16 +1472,18 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
sample: torch.Tensor,
temb: Optional[torch.Tensor] = None,
sample_posterior: bool = False,
encoder_causal: Optional[bool] = None,
decoder_causal: Optional[bool] = None,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[torch.Tensor, torch.Tensor]:
x = sample
posterior = self.encode(x).latent_dist
posterior = self.encode(x, causal=encoder_causal).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z, temb)
dec = self.decode(z, temb, causal=decoder_causal)
if not return_dict:
return (dec.sample,)
return dec
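
Round-trip through the combined `forward` with the new per-stage flags (continuing the hypothetical instantiation above):

out = vae(video, encoder_causal=True, decoder_causal=False)
reconstruction = out.sample  # DecoderOutput when return_dict=True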

View File

@@ -55,7 +55,10 @@ class AutoencoderKLLTX2VideoTests(ModelTesterMixin, AutoencoderTesterMixin, unit
"patch_size": 1,
"patch_size_t": 1,
"encoder_causal": True,
"decoder_causal": True,
"decoder_causal": False,
"encoder_spatial_padding_mode": "zeros",
# Full model uses `reflect` but this does not have deterministic backward implementation, so use `zeros`
"decoder_spatial_padding_mode": "zeros",
}
@property
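
The `reflect` caveat in the test config refers to PyTorch's determinism checks: reflection padding has no deterministic CUDA backward, so a reflect-padded conv fails under `torch.use_deterministic_algorithms(True)`. A sketch (CUDA device assumed):

import torch

torch.use_deterministic_algorithms(True)
conv = torch.nn.Conv3d(4, 4, kernel_size=3, padding=1, padding_mode="reflect").cuda()
x = torch.randn(1, 4, 4, 8, 8, device="cuda", requires_grad=True)
conv(x).sum().backward()  # raises: reflection pad backward has no deterministic implementation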