
Merge branch 'ltx-2-transformer' into make-scheduler-consistent

Daniel Gu committed 2025-12-30 20:25:59 +01:00
9 changed files with 1331 additions and 233 deletions

View File

@@ -71,7 +71,10 @@ LTX_2_0_VIDEO_VAE_RENAME_DICT = {
"per_channel_statistics.std-of-means": "latents_std",
}
LTX_2_0_AUDIO_VAE_RENAME_DICT = {}
LTX_2_0_AUDIO_VAE_RENAME_DICT = {
"per_channel_statistics.mean-of-means": "latents_mean",
"per_channel_statistics.std-of-means": "latents_std",
}
LTX_2_0_VOCODER_RENAME_DICT = {
"ups": "upsamplers",

View File

@@ -22,6 +22,9 @@ def convert_state_dict(state_dict: dict) -> dict:
if new_key.startswith("decoder."):
new_key = new_key[len("decoder.") :]
converted[f"decoder.{new_key}"] = value
converted["latents_mean"] = converted.pop("decoder.per_channel_statistics.mean-of-means")
converted["latents_std"] = converted.pop("decoder.per_channel_statistics.std-of-means")
return converted
@@ -87,9 +90,21 @@ def main() -> None:
latent_width,
device=device,
dtype=dtype,
generator=torch.Generator(device).manual_seed(42)
)
original_out = original_decoder(dummy)
from diffusers.pipelines.ltx2.pipeline_ltx2 import LTX2Pipeline
_, a_channels, a_time, a_freq = dummy.shape
dummy = dummy.permute(0, 2, 1, 3).reshape(-1, a_time, a_channels * a_freq)
dummy = LTX2Pipeline._denormalize_audio_latents(
dummy,
diffusers_model.latents_mean,
diffusers_model.latents_std,
)
dummy = dummy.view(-1, a_time, a_channels, a_freq).permute(0, 2, 1, 3)
diffusers_out = diffusers_model.decode(dummy).sample
torch.testing.assert_close(diffusers_out, original_out, rtol=1e-4, atol=1e-4)
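For reference, the round-trip the check above performs on the audio latents can be written as a small standalone helper. This is a hedged sketch, assuming latents of shape (batch, channels, time, mel_bins) and per-channel statistics whose length equals channels * mel_bins of the patchified view; the helper name and shapes are illustrative, not part of the converter.

import torch

def denormalize_patchified_audio_latents(
    latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor
) -> torch.Tensor:
    # (B, C, T, F) -> (B, T, C * F) so the statistics broadcast over the last dimension
    batch, channels, time, mel_bins = latents.shape
    flat = latents.permute(0, 2, 1, 3).reshape(batch, time, channels * mel_bins)
    # same formula as LTX2Pipeline._denormalize_audio_latents: x * std + mean
    flat = flat * latents_std.to(flat) + latents_mean.to(flat)
    # (B, T, C * F) -> (B, C, T, F)
    return flat.view(batch, time, channels, mel_bins).permute(0, 2, 1, 3)

# Illustrative shapes only: 8 latent channels x 16 mel bins gives 128 statistics entries.
dummy = torch.randn(1, 8, 16, 16)
restored = denormalize_patchified_audio_latents(dummy, torch.zeros(128), torch.ones(128))
assert restored.shape == dummy.shape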

View File

@@ -538,6 +538,7 @@ else:
"LTXLatentUpsamplePipeline",
"LTXPipeline",
"LTX2Pipeline",
"LTX2ImageToVideoPipeline",
"LucyEditPipeline",
"Lumina2Pipeline",
"Lumina2Text2ImgPipeline",
@@ -1245,6 +1246,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LTXLatentUpsamplePipeline,
LTXPipeline,
LTX2Pipeline,
LTX2ImageToVideoPipeline,
LucyEditPipeline,
Lumina2Pipeline,
Lumina2Text2ImgPipeline,

View File

@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from typing import Optional, Set, Tuple, Union
import torch
@@ -27,57 +26,6 @@ from .vae import AutoencoderMixin, DecoderOutput
LATENT_DOWNSAMPLE_FACTOR = 4
SUPPORTED_CAUSAL_AXES = {"none", "width", "height", "width-compatibility"}
AudioLatentShape = namedtuple(
"AudioLatentShape",
[
"batch",
"channels",
"frames",
"mel_bins",
],
)
def _resolve_causality_axis(causality_axis: Optional[str] = None) -> Optional[str]:
normalized = "none" if causality_axis is None else str(causality_axis).lower()
if normalized not in SUPPORTED_CAUSAL_AXES:
raise NotImplementedError(
f"Unsupported causality_axis '{causality_axis}'. Supported: {sorted(SUPPORTED_CAUSAL_AXES)}"
)
return None if normalized == "none" else normalized
def make_conv2d(
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: int = 1,
padding: Optional[Tuple[int, int, int, int]] = None,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
causality_axis: Optional[str] = None,
) -> nn.Module:
if causality_axis is not None:
return LTX2AudioCausalConv2d(
in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis
)
if padding is None:
padding = kernel_size // 2 if isinstance(kernel_size, int) else tuple(k // 2 for k in kernel_size)
return nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
)
class LTX2AudioCausalConv2d(nn.Module):
@@ -147,14 +95,6 @@ class LTX2AudioPixelNorm(nn.Module):
return x / rms
def build_normalization_layer(in_channels: int, *, num_groups: int = 32, normtype: str = "group") -> nn.Module:
if normtype == "group":
return nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
if normtype == "pixel":
return LTX2AudioPixelNorm(dim=1, eps=1e-6)
raise ValueError(f"Invalid normalization type: {normtype}")
class LTX2AudioAttnBlock(nn.Module):
def __init__(
self,
@@ -164,7 +104,12 @@ class LTX2AudioAttnBlock(nn.Module):
super().__init__()
self.in_channels = in_channels
self.norm = build_normalization_layer(in_channels, normtype=norm_type)
if norm_type == "group":
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
elif norm_type == "pixel":
self.norm = LTX2AudioPixelNorm(dim=1, eps=1e-6)
else:
raise ValueError(f"Invalid normalization type: {norm_type}")
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
@@ -211,23 +156,49 @@ class LTX2AudioResnetBlock(nn.Module):
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = build_normalization_layer(in_channels, normtype=norm_type)
if norm_type == "group":
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
elif norm_type == "pixel":
self.norm1 = LTX2AudioPixelNorm(dim=1, eps=1e-6)
else:
raise ValueError(f"Invalid normalization type: {norm_type}")
self.non_linearity = nn.SiLU()
self.conv1 = make_conv2d(in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
if causality_axis is not None:
self.conv1 = LTX2AudioCausalConv2d(
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
)
else:
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if temb_channels > 0:
self.temb_proj = nn.Linear(temb_channels, out_channels)
self.norm2 = build_normalization_layer(out_channels, normtype=norm_type)
if norm_type == "group":
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
elif norm_type == "pixel":
self.norm2 = LTX2AudioPixelNorm(dim=1, eps=1e-6)
else:
raise ValueError(f"Invalid normalization type: {norm_type}")
self.dropout = nn.Dropout(dropout)
self.conv2 = make_conv2d(out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
if causality_axis is not None:
self.conv2 = LTX2AudioCausalConv2d(
out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
)
else:
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = make_conv2d(
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
)
if causality_axis is not None:
self.conv_shortcut = LTX2AudioCausalConv2d(
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
)
else:
self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
else:
self.nin_shortcut = make_conv2d(
in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
)
if causality_axis is not None:
self.nin_shortcut = LTX2AudioCausalConv2d(
in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
)
else:
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
h = self.norm1(x)
@@ -254,7 +225,12 @@ class LTX2AudioUpsample(nn.Module):
self.with_conv = with_conv
self.causality_axis = causality_axis
if self.with_conv:
self.conv = make_conv2d(in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
if causality_axis is not None:
self.conv = LTX2AudioCausalConv2d(
in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis
)
else:
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
@@ -273,26 +249,6 @@ class LTX2AudioUpsample(nn.Module):
return x
class LTX2AudioPerChannelStatistics(nn.Module):
"""
Per-channel statistics for normalizing and denormalizing the latent representation. These statistics are computed
over the entire dataset and stored in the model's checkpoint under the AudioVAE state_dict.
"""
def __init__(self, latent_channels: int = 128) -> None:
super().__init__()
# Sayak notes: `empty` always causes problems in CI. Should we consider using `torch.ones`?
self.register_buffer("std-of-means", torch.empty(latent_channels))
self.register_buffer("mean-of-means", torch.empty(latent_channels))
def denormalize(self, x: torch.Tensor) -> torch.Tensor:
return (x * self.get_buffer("std-of-means").to(x)) + self.get_buffer("mean-of-means").to(x)
def normalize(self, x: torch.Tensor) -> torch.Tensor:
return (x - self.get_buffer("mean-of-means").to(x)) / self.get_buffer("std-of-means").to(x)
class LTX2AudioAudioPatchifier:
"""
Patchifier for spectrogram/audio latents.
@@ -316,11 +272,9 @@ class LTX2AudioAudioPatchifier:
batch, channels, time, freq = audio_latents.shape
return audio_latents.permute(0, 2, 1, 3).reshape(batch, time, channels * freq)
def unpatchify(self, audio_latents: torch.Tensor, output_shape: AudioLatentShape) -> torch.Tensor:
def unpatchify(self, audio_latents: torch.Tensor, channels: int, mel_bins: int) -> torch.Tensor:
batch, time, _ = audio_latents.shape
channels = output_shape.channels
freq = output_shape.mel_bins
return audio_latents.view(batch, time, channels, freq).permute(0, 2, 1, 3)
return audio_latents.view(batch, time, channels, mel_bins).permute(0, 2, 1, 3)
@property
def patch_size(self) -> Tuple[int, int, int]:
@@ -356,9 +310,6 @@ class LTX2AudioDecoder(nn.Module):
) -> None:
super().__init__()
resolved_causality_axis = _resolve_causality_axis(causality_axis)
self.per_channel_statistics = LTX2AudioPerChannelStatistics(latent_channels=base_channels)
self.sample_rate = sample_rate
self.mel_hop_length = mel_hop_length
self.is_causal = is_causal
@@ -384,116 +335,43 @@ class LTX2AudioDecoder(nn.Module):
self.latent_channels = latent_channels
self.channel_multipliers = ch_mult
self.attn_resolutions = attn_resolutions
self.causality_axis = resolved_causality_axis
self.causality_axis = causality_axis
base_block_channels = base_channels * self.channel_multipliers[-1]
base_resolution = resolution // (2 ** (self.num_resolutions - 1))
self.z_shape = (1, latent_channels, base_resolution, base_resolution)
self.conv_in = make_conv2d(
latent_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
)
self.non_linearity = nn.SiLU()
self.mid = self._build_mid_layers(base_block_channels, dropout, mid_block_add_attention)
self.up, final_block_channels = self._build_up_path(
initial_block_channels=base_block_channels,
dropout=dropout,
resamp_with_conv=True,
)
self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type)
self.conv_out = make_conv2d(
final_block_channels, output_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
)
def _adjust_output_shape(self, decoded_output: torch.Tensor, target_shape: AudioLatentShape) -> torch.Tensor:
_, _, current_time, current_freq = decoded_output.shape
target_channels = target_shape.channels
target_time = target_shape.frames
target_freq = target_shape.mel_bins
decoded_output = decoded_output[
:, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq)
]
time_padding_needed = target_time - decoded_output.shape[2]
freq_padding_needed = target_freq - decoded_output.shape[3]
if time_padding_needed > 0 or freq_padding_needed > 0:
padding = (
0,
max(freq_padding_needed, 0),
0,
max(time_padding_needed, 0),
)
decoded_output = F.pad(decoded_output, padding)
decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq]
return decoded_output
def forward(
self,
sample: torch.Tensor,
) -> torch.Tensor:
latent_shape = AudioLatentShape(
batch=sample.shape[0],
channels=sample.shape[1],
frames=sample.shape[2],
mel_bins=sample.shape[3],
)
sample_patched = self.patchifier.patchify(sample)
sample_denormalized = self.per_channel_statistics.denormalize(sample_patched)
sample = self.patchifier.unpatchify(sample_denormalized, latent_shape)
target_frames = latent_shape.frames * LATENT_DOWNSAMPLE_FACTOR
if self.causality_axis is not None:
target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1)
target_shape = AudioLatentShape(
batch=latent_shape.batch,
channels=self.out_ch,
frames=target_frames,
mel_bins=self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins,
)
hidden_features = self.conv_in(sample)
hidden_features = self._run_mid_layers(hidden_features)
hidden_features = self._run_upsampling_path(hidden_features)
decoded_output = self._finalize_output(hidden_features)
decoded_output = self._adjust_output_shape(decoded_output, target_shape)
return decoded_output
def _build_mid_layers(self, channels: int, dropout: float, add_attention: bool) -> nn.Module:
mid = nn.Module()
mid.block_1 = LTX2AudioResnetBlock(
in_channels=channels,
out_channels=channels,
self.conv_in = LTX2AudioCausalConv2d(
latent_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
)
else:
self.conv_in = nn.Conv2d(latent_channels, base_block_channels, kernel_size=3, stride=1, padding=1)
self.non_linearity = nn.SiLU()
self.mid = nn.Module()
self.mid.block_1 = LTX2AudioResnetBlock(
in_channels=base_block_channels,
out_channels=base_block_channels,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=self.causality_axis,
)
mid.attn_1 = LTX2AudioAttnBlock(channels, norm_type=self.norm_type) if add_attention else nn.Identity()
mid.block_2 = LTX2AudioResnetBlock(
in_channels=channels,
out_channels=channels,
if mid_block_add_attention:
self.mid.attn_1 = LTX2AudioAttnBlock(base_block_channels, norm_type=self.norm_type)
else:
self.mid.attn_1 = nn.Identity()
self.mid.block_2 = LTX2AudioResnetBlock(
in_channels=base_block_channels,
out_channels=base_block_channels,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=self.causality_axis,
)
return mid
def _build_up_path(
self, initial_block_channels: int, dropout: float, resamp_with_conv: bool
) -> tuple[nn.ModuleList, int]:
up_modules = nn.ModuleList()
block_in = initial_block_channels
self.up = nn.ModuleList()
block_in = base_block_channels
curr_res = self.resolution // (2 ** (self.num_resolutions - 1))
for level in reversed(range(self.num_resolutions)):
@@ -519,39 +397,89 @@ class LTX2AudioDecoder(nn.Module):
stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type))
if level != 0:
stage.upsample = LTX2AudioUpsample(block_in, resamp_with_conv, causality_axis=self.causality_axis)
stage.upsample = LTX2AudioUpsample(block_in, True, causality_axis=self.causality_axis)
curr_res *= 2
up_modules.insert(0, stage)
self.up.insert(0, stage)
return up_modules, block_in
final_block_channels = block_in
def _run_mid_layers(self, features: torch.Tensor) -> torch.Tensor:
features = self.mid.block_1(features, temb=None)
features = self.mid.attn_1(features)
return self.mid.block_2(features, temb=None)
if self.norm_type == "group":
self.norm_out = nn.GroupNorm(
num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True
)
elif self.norm_type == "pixel":
self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6)
else:
raise ValueError(f"Invalid normalization type: {self.norm_type}")
if self.causality_axis is not None:
self.conv_out = LTX2AudioCausalConv2d(
final_block_channels, output_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
)
else:
self.conv_out = nn.Conv2d(final_block_channels, output_channels, kernel_size=3, stride=1, padding=1)
def forward(
self,
sample: torch.Tensor,
) -> torch.Tensor:
_, _, frames, mel_bins = sample.shape
target_frames = frames * LATENT_DOWNSAMPLE_FACTOR
if self.causality_axis is not None:
target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1)
target_channels = self.out_ch
target_mel_bins = self.mel_bins if self.mel_bins is not None else mel_bins
hidden_features = self.conv_in(sample)
hidden_features = self.mid.block_1(hidden_features, temb=None)
hidden_features = self.mid.attn_1(hidden_features)
hidden_features = self.mid.block_2(hidden_features, temb=None)
def _run_upsampling_path(self, features: torch.Tensor) -> torch.Tensor:
for level in reversed(range(self.num_resolutions)):
stage = self.up[level]
for block_idx, block in enumerate(stage.block):
features = block(features, temb=None)
hidden_features = block(hidden_features, temb=None)
if stage.attn:
features = stage.attn[block_idx](features)
hidden_features = stage.attn[block_idx](hidden_features)
if level != 0 and hasattr(stage, "upsample"):
features = stage.upsample(features)
hidden_features = stage.upsample(hidden_features)
return features
def _finalize_output(self, features: torch.Tensor) -> torch.Tensor:
if self.give_pre_end:
return features
return hidden_features
hidden = self.norm_out(features)
hidden = self.norm_out(hidden_features)
hidden = self.non_linearity(hidden)
decoded = self.conv_out(hidden)
return torch.tanh(decoded) if self.tanh_out else decoded
decoded_output = self.conv_out(hidden)
decoded_output = torch.tanh(decoded_output) if self.tanh_out else decoded_output
_, _, current_time, current_freq = decoded_output.shape
target_time = target_frames
target_freq = target_mel_bins
decoded_output = decoded_output[
:, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq)
]
time_padding_needed = target_time - decoded_output.shape[2]
freq_padding_needed = target_freq - decoded_output.shape[3]
if time_padding_needed > 0 or freq_padding_needed > 0:
padding = (
0,
max(freq_padding_needed, 0),
0,
max(time_padding_needed, 0),
)
decoded_output = F.pad(decoded_output, padding)
decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq]
return decoded_output
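The crop-then-pad step at the end of the forward above can be isolated as follows; a minimal sketch assuming a (batch, channels, time, freq) tensor, with an illustrative function name.

import torch
import torch.nn.functional as F

def crop_or_pad_to_target(decoded: torch.Tensor, target_channels: int, target_time: int, target_freq: int) -> torch.Tensor:
    # Crop any excess channels/time/freq first (slicing past the end is a no-op)...
    decoded = decoded[:, :target_channels, :target_time, :target_freq]
    # ...then zero-pad time and freq on the right/bottom if the decoder produced too few steps.
    pad_time = max(target_time - decoded.shape[2], 0)
    pad_freq = max(target_freq - decoded.shape[3], 0)
    if pad_time or pad_freq:
        decoded = F.pad(decoded, (0, pad_freq, 0, pad_time))
    return decoded

# e.g. a (1, 1, 32, 64) decoder output trimmed to the 29-frame causal target
print(crop_or_pad_to_target(torch.randn(1, 1, 32, 64), 1, 29, 64).shape)  # torch.Size([1, 1, 29, 64])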
class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
@@ -583,7 +511,10 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
) -> None:
super().__init__()
resolved_causality_axis = _resolve_causality_axis(causality_axis)
supported_causality_axes = {"none", "width", "height", "width-compatibility"}
if causality_axis not in supported_causality_axes:
raise ValueError(f"{causality_axis=} is not valid. Supported values: {supported_causality_axes}")
attn_resolution_set = set(attn_resolutions) if attn_resolutions else attn_resolutions
self.decoder = LTX2AudioDecoder(
@@ -596,7 +527,7 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
resolution=resolution,
latent_channels=latent_channels,
norm_type=norm_type,
causality_axis=resolved_causality_axis,
causality_axis=causality_axis,
dropout=dropout,
mid_block_add_attention=mid_block_add_attention,
sample_rate=sample_rate,
@@ -605,6 +536,13 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
mel_bins=mel_bins,
)
# Per-channel statistics for normalizing and denormalizing the latent representation. These statistics are computed
# over the entire dataset and stored in the model's checkpoint under the AudioVAE state_dict.
latents_mean = torch.zeros((base_channels,))
latents_std = torch.ones((base_channels,))
self.register_buffer("latents_mean", latents_mean, persistent=True)
self.register_buffer("latents_std", latents_std, persistent=True)
# TODO: calculate programmatically instead of hardcoding
self.temporal_compression_ratio = LATENT_DOWNSAMPLE_FACTOR # 4
# TODO: confirm whether the mel compression ratio below is correct
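As a worked example of the temporal mapping implied by LATENT_DOWNSAMPLE_FACTOR and the causal offset used in the decoder's forward, a sketch of the same arithmetic (frame counts are illustrative):

LATENT_DOWNSAMPLE_FACTOR = 4

def decoded_audio_frames(latent_frames: int, causal: bool) -> int:
    target = latent_frames * LATENT_DOWNSAMPLE_FACTOR
    if causal:
        # mirrors the target_frames computation above: subtract (factor - 1), clamped to at least 1
        target = max(target - (LATENT_DOWNSAMPLE_FACTOR - 1), 1)
    return target

print(decoded_audio_frames(8, causal=False))  # 32
print(decoded_audio_frames(8, causal=True))   # 29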

View File

@@ -1051,6 +1051,7 @@ class LTX2VideoTransformer3DModel(
encoder_hidden_states: torch.Tensor,
audio_encoder_hidden_states: torch.Tensor,
timestep: torch.LongTensor,
audio_timestep: Optional[torch.LongTensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
audio_encoder_attention_mask: Optional[torch.Tensor] = None,
num_frames: Optional[int] = None,
@@ -1073,8 +1074,7 @@ class LTX2VideoTransformer3DModel(
Input patchified audio latents of shape (batch_size, num_audio_tokens, audio_in_channels).
encoder_hidden_states (`torch.Tensor`):
Input text embeddings of shape TODO.
timesteps (`torch.Tensor`):
Timestep information of shape (batch_size, num_train_timesteps).
TODO for the rest.
Returns:
`AudioVisualModelOutput` or `tuple`:
@@ -1097,6 +1097,9 @@ class LTX2VideoTransformer3DModel(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
# Determine timestep for audio.
audio_timestep = audio_timestep if audio_timestep is not None else timestep
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
@@ -1143,7 +1146,7 @@ class LTX2VideoTransformer3DModel(
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))
temb_audio, audio_embedded_timestep = self.audio_time_embed(
timestep.flatten(),
audio_timestep.flatten(),
batch_size=batch_size,
hidden_dtype=audio_hidden_states.dtype,
)
@@ -1165,12 +1168,12 @@ class LTX2VideoTransformer3DModel(
video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.view(batch_size, -1, video_cross_attn_a2v_gate.shape[-1])
audio_cross_attn_scale_shift, _ = self.av_cross_attn_audio_scale_shift(
timestep.flatten(),
audio_timestep.flatten(),
batch_size=batch_size,
hidden_dtype=audio_hidden_states.dtype,
)
audio_cross_attn_v2a_gate, _ = self.av_cross_attn_audio_v2a_gate(
timestep.flatten() * timestep_cross_attn_gate_scale_factor,
audio_timestep.flatten() * timestep_cross_attn_gate_scale_factor,
batch_size=batch_size,
hidden_dtype=audio_hidden_states.dtype,
)
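The new audio_timestep argument is optional; a quick sketch of the defaulting behaviour (illustrative tensors, not the transformer's full call signature):

from typing import Optional
import torch

def resolve_audio_timestep(timestep: torch.LongTensor, audio_timestep: Optional[torch.LongTensor] = None) -> torch.LongTensor:
    # Mirrors the line added above: fall back to the shared (video) timestep when no audio timestep is given.
    return timestep if audio_timestep is None else audio_timestep

print(resolve_audio_timestep(torch.tensor([500])))                       # tensor([500])  (previous behaviour)
print(resolve_audio_timestep(torch.tensor([500]), torch.tensor([250])))  # tensor([250])  (decoupled audio schedule)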

View File

@@ -288,7 +288,7 @@ else:
"LTXConditionPipeline",
"LTXLatentUpsamplePipeline",
]
_import_structure["ltx2"] = ["LTX2Pipeline"]
_import_structure["ltx2"] = ["LTX2Pipeline", "LTX2ImageToVideoPipeline"]
_import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
_import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
_import_structure["lucy"] = ["LucyEditPipeline"]
@@ -720,7 +720,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LEditsPPPipelineStableDiffusionXL,
)
from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline
from .ltx2 import LTX2Pipeline
from .ltx2 import LTX2Pipeline, LTX2ImageToVideoPipeline
from .lucy import LucyEditPipeline
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline

View File

@@ -23,6 +23,7 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_ltx2"] = ["LTX2Pipeline"]
_import_structure["pipeline_ltx2_image2video"] = ["LTX2ImageToVideoPipeline"]
_import_structure["text_encoder"] = ["LTX2AudioVisualTextEncoder"]
_import_structure["vocoder"] = ["LTX2Vocoder"]
@@ -35,6 +36,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_ltx2 import LTX2Pipeline
from .pipeline_ltx2_image2video import LTX2ImageToVideoPipeline
from .text_encoder import LTX2AudioVisualTextEncoder
from .vocoder import LTX2Vocoder

View File

@@ -496,16 +496,6 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
return latents
@staticmethod
def _normalize_latents(
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
) -> torch.Tensor:
# Normalize latents across the channel dimension [B, C, F, H, W]
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents = (latents - latents_mean) * scaling_factor / latents_std
return latents
@staticmethod
def _denormalize_latents(
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
@@ -516,6 +506,12 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
latents = latents * latents_std / scaling_factor + latents_mean
return latents
@staticmethod
def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor):
latents_mean = latents_mean.to(latents.device, latents.dtype)
latents_std = latents_std.to(latents.device, latents.dtype)
return (latents * latents_std) + latents_mean
@staticmethod
def _pack_audio_latents(
latents: torch.Tensor, patch_size: Optional[int] = None, patch_size_t: Optional[int] = None
@@ -1038,10 +1034,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
video = self.vae.decode(latents, timestep, return_dict=False)[0]
video = self.video_processor.postprocess_video(video, output_type=output_type)
audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins)
audio_latents = audio_latents.to(self.audio_vae.dtype)
# NOTE: currently, unlike the video VAE, we denormalize the audio latents inside the audio VAE decoder's
# decode method
audio_latents = self._denormalize_audio_latents(
audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std
)
audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins)
generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0]
audio = self.vocoder(generated_mel_spectrograms)
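For contrast, the two denormalization paths the pipeline now uses, side by side; a hedged sketch of the static methods above, with shapes as documented there:

import torch

def denormalize_video_latents(latents, latents_mean, latents_std, scaling_factor=1.0):
    # video latents are [B, C, F, H, W]; statistics broadcast over the channel dim and a scaling factor applies
    latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
    latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
    return latents * latents_std / scaling_factor + latents_mean

def denormalize_audio_latents(latents, latents_mean, latents_std):
    # packed audio latents; statistics broadcast over the last dim, no scaling factor
    return latents * latents_std.to(latents.device, latents.dtype) + latents_mean.to(latents.device, latents.dtype)

video = denormalize_video_latents(torch.randn(1, 4, 2, 8, 8), torch.zeros(4), torch.ones(4))
audio = denormalize_audio_latents(torch.randn(1, 16, 128), torch.zeros(128), torch.ones(128))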

File diff suppressed because it is too large.