diff --git a/scripts/convert_ltx2_to_diffusers.py b/scripts/convert_ltx2_to_diffusers.py index 85fa169af3..25a04e7893 100644 --- a/scripts/convert_ltx2_to_diffusers.py +++ b/scripts/convert_ltx2_to_diffusers.py @@ -240,7 +240,9 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A "patch_size_t": 1, "resnet_norm_eps": 1e-6, "encoder_causal": True, - "decoder_causal": True, + "decoder_causal": False, + "encoder_spatial_padding_mode": "zeros", + "decoder_spatial_padding_mode": "reflect", "spatial_compression_ratio": 32, "temporal_compression_ratio": 8, }, @@ -275,7 +277,9 @@ def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, A "patch_size_t": 1, "resnet_norm_eps": 1e-6, "encoder_causal": True, - "decoder_causal": True, + "decoder_causal": False, + "encoder_spatial_padding_mode": "zeros", + "decoder_spatial_padding_mode": "reflect", "spatial_compression_ratio": 32, "temporal_compression_ratio": 8, }, diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py index 755b92c10a..6e7b4d324f 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py @@ -29,8 +29,8 @@ from ..normalization import RMSNorm from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution -# Copied from diffusers.models.autoencoders.autoencoder_kl_ltx.LTXVideoCausalConv3d -class LTXVideoCausalConv3d(nn.Module): +# Like LTXCausalConv3d, but whether causal inference is performed can be specified at runtime +class LTX2VideoCausalConv3d(nn.Module): def __init__( self, in_channels: int, @@ -39,14 +39,12 @@ class LTXVideoCausalConv3d(nn.Module): stride: Union[int, Tuple[int, int, int]] = 1, dilation: Union[int, Tuple[int, int, int]] = 1, groups: int = 1, - padding_mode: str = "zeros", - is_causal: bool = True, + spatial_padding_mode: str = "zeros", ): super().__init__() self.in_channels = in_channels self.out_channels = out_channels - self.is_causal = is_causal self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size, kernel_size) dilation = dilation if isinstance(dilation, tuple) else (dilation, 1, 1) @@ -63,13 +61,13 @@ class LTXVideoCausalConv3d(nn.Module): dilation=dilation, groups=groups, padding=padding, - padding_mode=padding_mode, + padding_mode=spatial_padding_mode, ) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Tensor: time_kernel_size = self.kernel_size[0] - if self.is_causal: + if causal: pad_left = hidden_states[:, :, :1, :, :].repeat((1, 1, time_kernel_size - 1, 1, 1)) hidden_states = torch.concatenate([pad_left, hidden_states], dim=2) else: @@ -81,7 +79,8 @@ class LTXVideoCausalConv3d(nn.Module): return hidden_states -# Like LTXVideoResnetBlock3d, but uses a normal Conv3d instead of a causal Conv3d for the conv_shortcut +# Like LTXVideoResnetBlock3d, but uses new causal Conv3d, normal Conv3d for the conv_shortcut, and the spatial padding +# mode is configurable class LTX2VideoResnetBlock3d(nn.Module): r""" A 3D ResNet block used in the LTX 2.0 audiovisual model. 
@@ -111,9 +110,9 @@ class LTX2VideoResnetBlock3d(nn.Module): eps: float = 1e-6, elementwise_affine: bool = False, non_linearity: str = "swish", - is_causal: bool = True, inject_noise: bool = False, timestep_conditioning: bool = False, + spatial_padding_mode: str = "zeros", ) -> None: super().__init__() @@ -122,14 +121,20 @@ class LTX2VideoResnetBlock3d(nn.Module): self.nonlinearity = get_activation(non_linearity) self.norm1 = RMSNorm(in_channels, eps=1e-8, elementwise_affine=elementwise_affine) - self.conv1 = LTXVideoCausalConv3d( - in_channels=in_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal + self.conv1 = LTX2VideoCausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + spatial_padding_mode=spatial_padding_mode, ) self.norm2 = RMSNorm(out_channels, eps=1e-8, elementwise_affine=elementwise_affine) self.dropout = nn.Dropout(dropout) - self.conv2 = LTXVideoCausalConv3d( - in_channels=out_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal + self.conv2 = LTX2VideoCausalConv3d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + spatial_padding_mode=spatial_padding_mode, ) self.norm3 = None @@ -140,9 +145,6 @@ class LTX2VideoResnetBlock3d(nn.Module): self.conv_shortcut = nn.Conv3d( in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1 ) - # self.conv_shortcut = LTXVideoCausalConv3d( - # in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, is_causal=is_causal - # ) self.per_channel_scale1 = None self.per_channel_scale2 = None @@ -155,7 +157,11 @@ class LTX2VideoResnetBlock3d(nn.Module): self.scale_shift_table = nn.Parameter(torch.randn(4, in_channels) / in_channels**0.5) def forward( - self, inputs: torch.Tensor, temb: Optional[torch.Tensor] = None, generator: Optional[torch.Generator] = None + self, + inputs: torch.Tensor, + temb: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + causal: bool = True, ) -> torch.Tensor: hidden_states = inputs @@ -168,7 +174,7 @@ class LTX2VideoResnetBlock3d(nn.Module): hidden_states = hidden_states * (1 + scale_1) + shift_1 hidden_states = self.nonlinearity(hidden_states) - hidden_states = self.conv1(hidden_states) + hidden_states = self.conv1(hidden_states, causal=causal) if self.per_channel_scale1 is not None: spatial_shape = hidden_states.shape[-2:] @@ -184,7 +190,7 @@ class LTX2VideoResnetBlock3d(nn.Module): hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states) + hidden_states = self.conv2(hidden_states, causal=causal) if self.per_channel_scale2 is not None: spatial_shape = hidden_states.shape[-2:] @@ -203,15 +209,14 @@ class LTX2VideoResnetBlock3d(nn.Module): return hidden_states -# Copied from diffusers.models.autoencoders.autoencoder_kl_ltx.LTXVideoDownsampler3d +# Like LTX 1.0 LTXVideoDownsampler3d, but uses new causal Conv3d class LTXVideoDownsampler3d(nn.Module): def __init__( self, in_channels: int, out_channels: int, stride: Union[int, Tuple[int, int, int]] = 1, - is_causal: bool = True, - padding_mode: str = "zeros", + spatial_padding_mode: str = "zeros", ) -> None: super().__init__() @@ -220,16 +225,15 @@ class LTXVideoDownsampler3d(nn.Module): out_channels = out_channels // (self.stride[0] * self.stride[1] * self.stride[2]) - self.conv = LTXVideoCausalConv3d( + self.conv = LTX2VideoCausalConv3d( in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, - 
is_causal=is_causal, - padding_mode=padding_mode, + spatial_padding_mode=spatial_padding_mode, ) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Tensor: hidden_states = torch.cat([hidden_states[:, :, : self.stride[0] - 1], hidden_states], dim=2) residual = ( @@ -241,7 +245,7 @@ class LTXVideoDownsampler3d(nn.Module): residual = residual.unflatten(1, (-1, self.group_size)) residual = residual.mean(dim=2) - hidden_states = self.conv(hidden_states) + hidden_states = self.conv(hidden_states, causal=causal) hidden_states = ( hidden_states.unflatten(4, (-1, self.stride[2])) .unflatten(3, (-1, self.stride[1])) @@ -253,16 +257,15 @@ class LTXVideoDownsampler3d(nn.Module): return hidden_states -# Copied from diffusers.models.autoencoders.autoencoder_kl_ltx.LTXVideoUpsampler3d +# Like LTX 1.0 LTXVideoUpsampler3d, but uses new causal Conv3d class LTXVideoUpsampler3d(nn.Module): def __init__( self, in_channels: int, stride: Union[int, Tuple[int, int, int]] = 1, - is_causal: bool = True, residual: bool = False, upscale_factor: int = 1, - padding_mode: str = "zeros", + spatial_padding_mode: str = "zeros", ) -> None: super().__init__() @@ -272,16 +275,15 @@ class LTXVideoUpsampler3d(nn.Module): out_channels = (in_channels * stride[0] * stride[1] * stride[2]) // upscale_factor - self.conv = LTXVideoCausalConv3d( + self.conv = LTX2VideoCausalConv3d( in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, - is_causal=is_causal, - padding_mode=padding_mode, + spatial_padding_mode=spatial_padding_mode, ) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, causal: bool = True) -> torch.Tensor: batch_size, num_channels, num_frames, height, width = hidden_states.shape if self.residual: @@ -293,7 +295,7 @@ class LTXVideoUpsampler3d(nn.Module): residual = residual.repeat(1, repeats, 1, 1, 1) residual = residual[:, :, self.stride[0] - 1 :] - hidden_states = self.conv(hidden_states) + hidden_states = self.conv(hidden_states, causal=causal) hidden_states = hidden_states.reshape( batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width ) @@ -342,8 +344,8 @@ class LTX2VideoDownBlock3D(nn.Module): resnet_eps: float = 1e-6, resnet_act_fn: str = "swish", spatio_temporal_scale: bool = True, - is_causal: bool = True, downsample_type: str = "conv", + spatial_padding_mode: str = "zeros", ): super().__init__() @@ -358,7 +360,7 @@ class LTX2VideoDownBlock3D(nn.Module): dropout=dropout, eps=resnet_eps, non_linearity=resnet_act_fn, - is_causal=is_causal, + spatial_padding_mode=spatial_padding_mode, ) ) self.resnets = nn.ModuleList(resnets) @@ -369,30 +371,39 @@ class LTX2VideoDownBlock3D(nn.Module): if downsample_type == "conv": self.downsamplers.append( - LTXVideoCausalConv3d( + LTX2VideoCausalConv3d( in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=(2, 2, 2), - is_causal=is_causal, + spatial_padding_mode=spatial_padding_mode, ) ) elif downsample_type == "spatial": self.downsamplers.append( LTXVideoDownsampler3d( - in_channels=in_channels, out_channels=out_channels, stride=(1, 2, 2), is_causal=is_causal + in_channels=in_channels, + out_channels=out_channels, + stride=(1, 2, 2), + spatial_padding_mode=spatial_padding_mode, ) ) elif downsample_type == "temporal": self.downsamplers.append( LTXVideoDownsampler3d( - in_channels=in_channels, out_channels=out_channels, stride=(2, 1, 1), 
is_causal=is_causal + in_channels=in_channels, + out_channels=out_channels, + stride=(2, 1, 1), + spatial_padding_mode=spatial_padding_mode, ) ) elif downsample_type == "spatiotemporal": self.downsamplers.append( LTXVideoDownsampler3d( - in_channels=in_channels, out_channels=out_channels, stride=(2, 2, 2), is_causal=is_causal + in_channels=in_channels, + out_channels=out_channels, + stride=(2, 2, 2), + spatial_padding_mode=spatial_padding_mode, ) ) @@ -403,18 +414,19 @@ class LTX2VideoDownBlock3D(nn.Module): hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, generator: Optional[torch.Generator] = None, + causal: bool = True, ) -> torch.Tensor: r"""Forward method of the `LTXDownBlock3D` class.""" for i, resnet in enumerate(self.resnets): if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator, causal) else: - hidden_states = resnet(hidden_states, temb, generator) + hidden_states = resnet(hidden_states, temb, generator, causal=causal) if self.downsamplers is not None: for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) + hidden_states = downsampler(hidden_states, causal=causal) return hidden_states @@ -449,9 +461,9 @@ class LTX2VideoMidBlock3d(nn.Module): dropout: float = 0.0, resnet_eps: float = 1e-6, resnet_act_fn: str = "swish", - is_causal: bool = True, inject_noise: bool = False, timestep_conditioning: bool = False, + spatial_padding_mode: str = "zeros", ) -> None: super().__init__() @@ -468,9 +480,9 @@ class LTX2VideoMidBlock3d(nn.Module): dropout=dropout, eps=resnet_eps, non_linearity=resnet_act_fn, - is_causal=is_causal, inject_noise=inject_noise, timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, ) ) self.resnets = nn.ModuleList(resnets) @@ -482,6 +494,7 @@ class LTX2VideoMidBlock3d(nn.Module): hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, generator: Optional[torch.Generator] = None, + causal: bool = True, ) -> torch.Tensor: r"""Forward method of the `LTXMidBlock3D` class.""" @@ -497,9 +510,9 @@ class LTX2VideoMidBlock3d(nn.Module): for i, resnet in enumerate(self.resnets): if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator, causal) else: - hidden_states = resnet(hidden_states, temb, generator) + hidden_states = resnet(hidden_states, temb, generator, causal=causal) return hidden_states @@ -540,11 +553,11 @@ class LTX2VideoUpBlock3d(nn.Module): resnet_eps: float = 1e-6, resnet_act_fn: str = "swish", spatio_temporal_scale: bool = True, - is_causal: bool = True, inject_noise: bool = False, timestep_conditioning: bool = False, upsample_residual: bool = False, upscale_factor: int = 1, + spatial_padding_mode: str = "zeros", ): super().__init__() @@ -562,9 +575,9 @@ class LTX2VideoUpBlock3d(nn.Module): dropout=dropout, eps=resnet_eps, non_linearity=resnet_act_fn, - is_causal=is_causal, inject_noise=inject_noise, timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, ) self.upsamplers = None @@ -574,9 +587,9 @@ class LTX2VideoUpBlock3d(nn.Module): LTXVideoUpsampler3d( out_channels * upscale_factor, stride=(2, 2, 2), - is_causal=is_causal, residual=upsample_residual, 
upscale_factor=upscale_factor, + spatial_padding_mode=spatial_padding_mode, ) ] ) @@ -590,9 +603,9 @@ class LTX2VideoUpBlock3d(nn.Module): dropout=dropout, eps=resnet_eps, non_linearity=resnet_act_fn, - is_causal=is_causal, inject_noise=inject_noise, timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, ) ) self.resnets = nn.ModuleList(resnets) @@ -604,9 +617,10 @@ class LTX2VideoUpBlock3d(nn.Module): hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, generator: Optional[torch.Generator] = None, + causal: bool = True, ) -> torch.Tensor: if self.conv_in is not None: - hidden_states = self.conv_in(hidden_states, temb, generator) + hidden_states = self.conv_in(hidden_states, temb, generator, causal=causal) if self.time_embedder is not None: temb = self.time_embedder( @@ -620,13 +634,13 @@ class LTX2VideoUpBlock3d(nn.Module): if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) + hidden_states = upsampler(hidden_states, causal=causal) for i, resnet in enumerate(self.resnets): if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator) + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator, causal) else: - hidden_states = resnet(hidden_states, temb, generator) + hidden_states = resnet(hidden_states, temb, generator, causal=causal) return hidden_states @@ -682,21 +696,23 @@ class LTX2VideoEncoder3d(nn.Module): patch_size_t: int = 1, resnet_norm_eps: float = 1e-6, is_causal: bool = True, + spatial_padding_mode: str = "zeros", ): super().__init__() self.patch_size = patch_size self.patch_size_t = patch_size_t self.in_channels = in_channels * patch_size**2 + self.is_causal = is_causal output_channel = out_channels - self.conv_in = LTXVideoCausalConv3d( + self.conv_in = LTX2VideoCausalConv3d( in_channels=self.in_channels, out_channels=output_channel, kernel_size=3, stride=1, - is_causal=is_causal, + spatial_padding_mode=spatial_padding_mode, ) # down blocks @@ -713,8 +729,8 @@ class LTX2VideoEncoder3d(nn.Module): num_layers=layers_per_block[i], resnet_eps=resnet_norm_eps, spatio_temporal_scale=spatio_temporal_scaling[i], - is_causal=is_causal, downsample_type=downsample_type[i], + spatial_padding_mode=spatial_padding_mode, ) else: raise ValueError(f"Unknown down block type: {down_block_types[i]}") @@ -726,19 +742,23 @@ class LTX2VideoEncoder3d(nn.Module): in_channels=output_channel, num_layers=layers_per_block[-1], resnet_eps=resnet_norm_eps, - is_causal=is_causal, + spatial_padding_mode=spatial_padding_mode, ) # out self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False) self.conv_act = nn.SiLU() - self.conv_out = LTXVideoCausalConv3d( - in_channels=output_channel, out_channels=out_channels + 1, kernel_size=3, stride=1, is_causal=is_causal + self.conv_out = LTX2VideoCausalConv3d( + in_channels=output_channel, + out_channels=out_channels + 1, + kernel_size=3, + stride=1, + spatial_padding_mode=spatial_padding_mode, ) self.gradient_checkpointing = False - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, causal: Optional[bool] = None) -> torch.Tensor: r"""The forward method of the `LTXVideoEncoder3d` class.""" p = self.patch_size @@ -748,28 +768,29 @@ class LTX2VideoEncoder3d(nn.Module): post_patch_num_frames = num_frames // p_t post_patch_height = height // p post_patch_width = width // p 
+ causal = self.is_causal if causal is None else causal hidden_states = hidden_states.reshape( batch_size, num_channels, post_patch_num_frames, p_t, post_patch_height, p, post_patch_width, p ) # Thanks for driving me insane with the weird patching order :( hidden_states = hidden_states.permute(0, 1, 3, 7, 5, 2, 4, 6).flatten(1, 4) - hidden_states = self.conv_in(hidden_states) + hidden_states = self.conv_in(hidden_states, causal=causal) if torch.is_grad_enabled() and self.gradient_checkpointing: for down_block in self.down_blocks: - hidden_states = self._gradient_checkpointing_func(down_block, hidden_states) + hidden_states = self._gradient_checkpointing_func(down_block, hidden_states, None, None, causal) - hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states) + hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, None, None, causal) else: for down_block in self.down_blocks: - hidden_states = down_block(hidden_states) + hidden_states = down_block(hidden_states, causal=causal) - hidden_states = self.mid_block(hidden_states) + hidden_states = self.mid_block(hidden_states, causal=causal) hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1) hidden_states = self.conv_act(hidden_states) - hidden_states = self.conv_out(hidden_states) + hidden_states = self.conv_out(hidden_states, causal=causal) last_channel = hidden_states[:, -1:] last_channel = last_channel.repeat(1, hidden_states.size(1) - 2, 1, 1, 1) @@ -817,17 +838,19 @@ class LTX2VideoDecoder3d(nn.Module): patch_size: int = 4, patch_size_t: int = 1, resnet_norm_eps: float = 1e-6, - is_causal: bool = True, + is_causal: bool = False, inject_noise: Tuple[bool, ...] = (False, False, False), timestep_conditioning: bool = False, upsample_residual: Tuple[bool, ...] = (True, True, True), upsample_factor: Tuple[bool, ...]
= (2, 2, 2), + spatial_padding_mode: str = "reflect", ) -> None: super().__init__() self.patch_size = patch_size self.patch_size_t = patch_size_t self.out_channels = out_channels * patch_size**2 + self.is_causal = is_causal block_out_channels = tuple(reversed(block_out_channels)) spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling)) @@ -837,17 +860,21 @@ class LTX2VideoDecoder3d(nn.Module): upsample_factor = tuple(reversed(upsample_factor)) output_channel = block_out_channels[0] - self.conv_in = LTXVideoCausalConv3d( - in_channels=in_channels, out_channels=output_channel, kernel_size=3, stride=1, is_causal=is_causal + self.conv_in = LTX2VideoCausalConv3d( + in_channels=in_channels, + out_channels=output_channel, + kernel_size=3, + stride=1, + spatial_padding_mode=spatial_padding_mode, ) self.mid_block = LTX2VideoMidBlock3d( in_channels=output_channel, num_layers=layers_per_block[0], resnet_eps=resnet_norm_eps, - is_causal=is_causal, inject_noise=inject_noise[0], timestep_conditioning=timestep_conditioning, + spatial_padding_mode=spatial_padding_mode, ) # up blocks @@ -863,11 +890,11 @@ class LTX2VideoDecoder3d(nn.Module): num_layers=layers_per_block[i + 1], resnet_eps=resnet_norm_eps, spatio_temporal_scale=spatio_temporal_scaling[i], - is_causal=is_causal, inject_noise=inject_noise[i + 1], timestep_conditioning=timestep_conditioning, upsample_residual=upsample_residual[i], upscale_factor=upsample_factor[i], + spatial_padding_mode=spatial_padding_mode, ) self.up_blocks.append(up_block) @@ -875,8 +902,12 @@ class LTX2VideoDecoder3d(nn.Module): # out self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False) self.conv_act = nn.SiLU() - self.conv_out = LTXVideoCausalConv3d( - in_channels=output_channel, out_channels=self.out_channels, kernel_size=3, stride=1, is_causal=is_causal + self.conv_out = LTX2VideoCausalConv3d( + in_channels=output_channel, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + spatial_padding_mode=spatial_padding_mode, ) # timestep embedding @@ -890,22 +921,26 @@ class LTX2VideoDecoder3d(nn.Module): self.gradient_checkpointing = False - def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: - hidden_states = self.conv_in(hidden_states) + def forward( + self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, causal: Optional[bool] = None + ) -> torch.Tensor: + causal = self.is_causal if causal is None else causal + + hidden_states = self.conv_in(hidden_states, causal=causal) if self.timestep_scale_multiplier is not None: temb = temb * self.timestep_scale_multiplier if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb) + hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb, None, causal) for up_block in self.up_blocks: - hidden_states = self._gradient_checkpointing_func(up_block, hidden_states, temb) + hidden_states = self._gradient_checkpointing_func(up_block, hidden_states, temb, None, causal) else: - hidden_states = self.mid_block(hidden_states, temb) + hidden_states = self.mid_block(hidden_states, temb, causal=causal) for up_block in self.up_blocks: - hidden_states = up_block(hidden_states, temb) + hidden_states = up_block(hidden_states, temb, causal=causal) hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1) @@ -923,7 +958,7 @@ class LTX2VideoDecoder3d(nn.Module): hidden_states = hidden_states * (1 + scale) + shift hidden_states =
self.conv_act(hidden_states) - hidden_states = self.conv_out(hidden_states) + hidden_states = self.conv_out(hidden_states, causal=causal) p = self.patch_size p_t = self.patch_size_t @@ -1006,6 +1041,8 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig scaling_factor: float = 1.0, encoder_causal: bool = True, decoder_causal: bool = True, + encoder_spatial_padding_mode: str = "zeros", + decoder_spatial_padding_mode: str = "reflect", spatial_compression_ratio: int = None, temporal_compression_ratio: int = None, ) -> None: @@ -1023,6 +1060,7 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig patch_size_t=patch_size_t, resnet_norm_eps=resnet_norm_eps, is_causal=encoder_causal, + spatial_padding_mode=encoder_spatial_padding_mode, ) self.decoder = LTX2VideoDecoder3d( in_channels=latent_channels, @@ -1038,6 +1076,7 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig inject_noise=decoder_inject_noise, upsample_residual=upsample_residual, upsample_factor=upsample_factor, + spatial_padding_mode=decoder_spatial_padding_mode, ) latents_mean = torch.zeros((latent_channels,), requires_grad=False) @@ -1120,22 +1159,22 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames - def _encode(self, x: torch.Tensor) -> torch.Tensor: + def _encode(self, x: torch.Tensor, causal: Optional[bool] = None) -> torch.Tensor: batch_size, num_channels, num_frames, height, width = x.shape if self.use_framewise_decoding and num_frames > self.tile_sample_min_num_frames: - return self._temporal_tiled_encode(x) + return self._temporal_tiled_encode(x, causal=causal) if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): - return self.tiled_encode(x) + return self.tiled_encode(x, causal=causal) - enc = self.encoder(x) + enc = self.encoder(x, causal=causal) return enc @apply_forward_hook def encode( - self, x: torch.Tensor, return_dict: bool = True + self, x: torch.Tensor, causal: Optional[bool] = None, return_dict: bool = True ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: """ Encode a batch of images into latents. @@ -1150,10 +1189,10 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. 
""" if self.use_slicing and x.shape[0] > 1: - encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + encoded_slices = [self._encode(x_slice, causal=causal) for x_slice in x.split(1)] h = torch.cat(encoded_slices) else: - h = self._encode(x) + h = self._encode(x, causal=causal) posterior = DiagonalGaussianDistribution(h) if not return_dict: @@ -1161,7 +1200,11 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig return AutoencoderKLOutput(latent_dist=posterior) def _decode( - self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True + self, + z: torch.Tensor, + temb: Optional[torch.Tensor] = None, + causal: Optional[bool] = None, + return_dict: bool = True, ) -> Union[DecoderOutput, torch.Tensor]: batch_size, num_channels, num_frames, height, width = z.shape tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio @@ -1169,12 +1212,12 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames: - return self._temporal_tiled_decode(z, temb, return_dict=return_dict) + return self._temporal_tiled_decode(z, temb, causal=causal, return_dict=return_dict) if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): - return self.tiled_decode(z, temb, return_dict=return_dict) + return self.tiled_decode(z, temb, causal=causal, return_dict=return_dict) - dec = self.decoder(z, temb) + dec = self.decoder(z, temb, causal=causal) if not return_dict: return (dec,) @@ -1183,7 +1226,11 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig @apply_forward_hook def decode( - self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True + self, + z: torch.Tensor, + temb: Optional[torch.Tensor] = None, + causal: Optional[bool] = None, + return_dict: bool = True, ) -> Union[DecoderOutput, torch.Tensor]: """ Decode a batch of images. @@ -1201,13 +1248,13 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig if self.use_slicing and z.shape[0] > 1: if temb is not None: decoded_slices = [ - self._decode(z_slice, t_slice).sample for z_slice, t_slice in (z.split(1), temb.split(1)) + self._decode(z_slice, t_slice, causal=causal).sample for z_slice, t_slice in (z.split(1), temb.split(1)) ] else: - decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded_slices = [self._decode(z_slice, causal=causal).sample for z_slice in z.split(1)] decoded = torch.cat(decoded_slices) else: - decoded = self._decode(z, temb).sample + decoded = self._decode(z, temb, causal=causal).sample if not return_dict: return (decoded,) @@ -1238,7 +1285,7 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig ) return b - def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + def tiled_encode(self, x: torch.Tensor, causal: Optional[bool] = None) -> torch.Tensor: r"""Encode a batch of images using a tiled encoder. 
Args: @@ -1267,7 +1314,8 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig row = [] for j in range(0, width, self.tile_sample_stride_width): time = self.encoder( - x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] + x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width], + causal=causal, ) row.append(time) @@ -1290,7 +1338,7 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig return enc def tiled_decode( - self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True + self, z: torch.Tensor, temb: Optional[torch.Tensor], causal: Optional[bool] = None, return_dict: bool = True ) -> Union[DecoderOutput, torch.Tensor]: r""" Decode a batch of images using a tiled decoder. @@ -1324,7 +1372,9 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig for i in range(0, height, tile_latent_stride_height): row = [] for j in range(0, width, tile_latent_stride_width): - time = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb) + time = self.decoder( + z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb, causal=causal + ) row.append(time) rows.append(row) @@ -1349,7 +1399,7 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig return DecoderOutput(sample=dec) - def _temporal_tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: + def _temporal_tiled_encode(self, x: torch.Tensor, causal: Optional[bool] = None) -> AutoencoderKLOutput: batch_size, num_channels, num_frames, height, width = x.shape latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1 @@ -1361,9 +1411,9 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig for i in range(0, num_frames, self.tile_sample_stride_num_frames): tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :] if self.use_tiling and (height > self.tile_sample_min_height or width > self.tile_sample_min_width): - tile = self.tiled_encode(tile) + tile = self.tiled_encode(tile, causal=causal) else: - tile = self.encoder(tile) + tile = self.encoder(tile, causal=causal) if i > 0: tile = tile[:, :, 1:, :, :] row.append(tile) @@ -1380,7 +1430,7 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig return enc def _temporal_tiled_decode( - self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True + self, z: torch.Tensor, temb: Optional[torch.Tensor], causal: Optional[bool] = None, return_dict: bool = True ) -> Union[DecoderOutput, torch.Tensor]: batch_size, num_channels, num_frames, height, width = z.shape num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1 @@ -1395,9 +1445,9 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig for i in range(0, num_frames, tile_latent_stride_num_frames): tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :] if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height): - decoded = self.tiled_decode(tile, temb, return_dict=True).sample + decoded = self.tiled_decode(tile, temb, causal=causal, return_dict=True).sample else: - decoded = self.decoder(tile, temb) + decoded = self.decoder(tile, temb, causal=causal) if i > 0: decoded = decoded[:, :, :-1, :, :] row.append(decoded) @@ -1422,16 +1472,18 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, 
ConfigMixin, FromOrig sample: torch.Tensor, temb: Optional[torch.Tensor] = None, sample_posterior: bool = False, + encoder_causal: Optional[bool] = None, + decoder_causal: Optional[bool] = None, return_dict: bool = True, generator: Optional[torch.Generator] = None, ) -> Union[torch.Tensor, torch.Tensor]: x = sample - posterior = self.encode(x).latent_dist + posterior = self.encode(x, causal=encoder_causal).latent_dist if sample_posterior: z = posterior.sample(generator=generator) else: z = posterior.mode() - dec = self.decode(z, temb) + dec = self.decode(z, temb, causal=decoder_causal) if not return_dict: return (dec.sample,) return dec diff --git a/tests/models/autoencoders/test_models_autoencoder_ltx2_video.py b/tests/models/autoencoders/test_models_autoencoder_ltx2_video.py index 703ba54f89..25984d621a 100644 --- a/tests/models/autoencoders/test_models_autoencoder_ltx2_video.py +++ b/tests/models/autoencoders/test_models_autoencoder_ltx2_video.py @@ -55,7 +55,10 @@ class AutoencoderKLLTX2VideoTests(ModelTesterMixin, AutoencoderTesterMixin, unit "patch_size": 1, "patch_size_t": 1, "encoder_causal": True, - "decoder_causal": True, + "decoder_causal": False, + "encoder_spatial_padding_mode": "zeros", + # Full model uses `reflect` but this does not have deterministic backward implementation, so use `zeros` + "decoder_spatial_padding_mode": "zeros", } @property
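Example usage (not part of the patch): a minimal sketch of how the runtime causality override introduced above interacts with the converted config. The checkpoint path is a placeholder, loading assumes a VAE saved by the updated conversion script (with `encoder_spatial_padding_mode` / `decoder_spatial_padding_mode` and `decoder_causal=False`), and the import uses the module path from this diff rather than assuming a top-level re-export.

import torch

from diffusers.models.autoencoders.autoencoder_kl_ltx2 import AutoencoderKLLTX2Video

# Placeholder path to a checkpoint produced by scripts/convert_ltx2_to_diffusers.py.
vae = AutoencoderKLLTX2Video.from_pretrained("<converted-ltx2-vae>", torch_dtype=torch.float32)
vae.eval()

# (batch, channels, frames, height, width); the VAE compresses 32x spatially and 8x temporally,
# so height/width should be divisible by 32 and frame counts are typically of the form 8 * k + 1.
video = torch.randn(1, 3, 9, 256, 256)

with torch.no_grad():
    # Without a `causal` argument, encode/decode fall back to the configured
    # `encoder_causal` / `decoder_causal` values.
    latents = vae.encode(video).latent_dist.sample()

    # The decoder's temporal padding can be forced to be causal for this call only,
    # overriding `decoder_causal=False`. Checkpoints that enable timestep conditioning
    # may additionally require a decode timestep embedding via `temb`.
    frames = vae.decode(latents, temb=None, causal=True).sample

The same per-call override is exposed end to end: `AutoencoderKLLTX2Video.forward` accepts `encoder_causal` and `decoder_causal`, and each encoder/decoder block forwards a `causal` flag down to `LTX2VideoCausalConv3d`.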