Merge branch 'ltx-2-transformer' into make-scheduler-consistent
@@ -71,7 +71,10 @@ LTX_2_0_VIDEO_VAE_RENAME_DICT = {
    "per_channel_statistics.std-of-means": "latents_std",
}

LTX_2_0_AUDIO_VAE_RENAME_DICT = {}
LTX_2_0_AUDIO_VAE_RENAME_DICT = {
    "per_channel_statistics.mean-of-means": "latents_mean",
    "per_channel_statistics.std-of-means": "latents_std",
}

LTX_2_0_VOCODER_RENAME_DICT = {
    "ups": "upsamplers",
@@ -22,6 +22,9 @@ def convert_state_dict(state_dict: dict) -> dict:
        if new_key.startswith("decoder."):
            new_key = new_key[len("decoder.") :]
        converted[f"decoder.{new_key}"] = value

    converted["latents_mean"] = converted.pop("decoder.per_channel_statistics.mean-of-means")
    converted["latents_std"] = converted.pop("decoder.per_channel_statistics.std-of-means")
    return converted
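For orientation, a minimal sketch of the key handling this hunk converges on; `_sketch_convert` and the dummy tensors are illustrative stand-ins, not code from the diff:

import torch

def _sketch_convert(state_dict: dict) -> dict:
    # Prefix every decoder weight exactly once, then hoist the per-channel
    # statistics out of the decoder namespace into top-level buffer names.
    converted = {}
    for key, value in state_dict.items():
        new_key = key[len("decoder."):] if key.startswith("decoder.") else key
        converted[f"decoder.{new_key}"] = value

    converted["latents_mean"] = converted.pop("decoder.per_channel_statistics.mean-of-means")
    converted["latents_std"] = converted.pop("decoder.per_channel_statistics.std-of-means")
    return converted

state_dict = {
    "decoder.conv_in.weight": torch.zeros(4),
    "per_channel_statistics.mean-of-means": torch.zeros(128),
    "per_channel_statistics.std-of-means": torch.ones(128),
}
converted = _sketch_convert(state_dict)
assert set(converted) == {"decoder.conv_in.weight", "latents_mean", "latents_std"}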
@@ -87,9 +90,21 @@ def main() -> None:
        latent_width,
        device=device,
        dtype=dtype,
        generator=torch.Generator(device).manual_seed(42)
    )

    original_out = original_decoder(dummy)

    from diffusers.pipelines.ltx2.pipeline_ltx2 import LTX2Pipeline

    _, a_channels, a_time, a_freq = dummy.shape
    dummy = dummy.permute(0, 2, 1, 3).reshape(-1, a_time, a_channels * a_freq)
    dummy = LTX2Pipeline._denormalize_audio_latents(
        dummy,
        diffusers_model.latents_mean,
        diffusers_model.latents_std,
    )
    dummy = dummy.view(-1, a_time, a_channels, a_freq).permute(0, 2, 1, 3)
    diffusers_out = diffusers_model.decode(dummy).sample

    torch.testing.assert_close(diffusers_out, original_out, rtol=1e-4, atol=1e-4)
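The reshapes above follow the audio latent layout used throughout the diff: `[batch, channels, time, mel_bins]` tensors are flattened into channel-last tokens of shape `[batch, time, channels * mel_bins]` before the statistics are applied, then restored. A self-contained round-trip sketch with illustrative sizes:

import torch

batch, channels, time, mel_bins = 2, 8, 24, 16
latents = torch.randn(batch, channels, time, mel_bins)

tokens = latents.permute(0, 2, 1, 3).reshape(batch, time, channels * mel_bins)  # patchify
restored = tokens.view(batch, time, channels, mel_bins).permute(0, 2, 1, 3)     # unpatchify

assert torch.equal(restored, latents)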
@@ -538,6 +538,7 @@ else:
            "LTXLatentUpsamplePipeline",
            "LTXPipeline",
            "LTX2Pipeline",
            "LTX2ImageToVideoPipeline",
            "LucyEditPipeline",
            "Lumina2Pipeline",
            "Lumina2Text2ImgPipeline",
@@ -1245,6 +1246,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        LTXLatentUpsamplePipeline,
        LTXPipeline,
        LTX2Pipeline,
        LTX2ImageToVideoPipeline,
        LucyEditPipeline,
        Lumina2Pipeline,
        Lumina2Text2ImgPipeline,
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import namedtuple
from typing import Optional, Set, Tuple, Union

import torch
@@ -27,57 +26,6 @@ from .vae import AutoencoderMixin, DecoderOutput
LATENT_DOWNSAMPLE_FACTOR = 4
SUPPORTED_CAUSAL_AXES = {"none", "width", "height", "width-compatibility"}


AudioLatentShape = namedtuple(
    "AudioLatentShape",
    [
        "batch",
        "channels",
        "frames",
        "mel_bins",
    ],
)


def _resolve_causality_axis(causality_axis: Optional[str] = None) -> Optional[str]:
    normalized = "none" if causality_axis is None else str(causality_axis).lower()
    if normalized not in SUPPORTED_CAUSAL_AXES:
        raise NotImplementedError(
            f"Unsupported causality_axis '{causality_axis}'. Supported: {sorted(SUPPORTED_CAUSAL_AXES)}"
        )
    return None if normalized == "none" else normalized

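A quick usage sketch of the helper above (values chosen for illustration; assumes the function is in scope):

# None and "none" both resolve to no causality; supported axes pass through lowercased.
assert _resolve_causality_axis(None) is None
assert _resolve_causality_axis("none") is None
assert _resolve_causality_axis("Width") == "width"
# Anything outside SUPPORTED_CAUSAL_AXES, e.g. "time", raises NotImplementedError.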
def make_conv2d(
    in_channels: int,
    out_channels: int,
    kernel_size: Union[int, Tuple[int, int]],
    stride: int = 1,
    padding: Optional[Tuple[int, int, int, int]] = None,
    dilation: int = 1,
    groups: int = 1,
    bias: bool = True,
    causality_axis: Optional[str] = None,
) -> nn.Module:
    if causality_axis is not None:
        return LTX2AudioCausalConv2d(
            in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis
        )
    if padding is None:
        padding = kernel_size // 2 if isinstance(kernel_size, int) else tuple(k // 2 for k in kernel_size)

    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        dilation,
        groups,
        bias,
    )

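Illustrative use of the factory above (assumes the module's classes are in scope): a causality axis routes to the causal wrapper, otherwise a plain `nn.Conv2d` is built with same-style padding derived from the kernel size.

causal_conv = make_conv2d(64, 64, kernel_size=3, causality_axis="width")  # LTX2AudioCausalConv2d
plain_conv = make_conv2d(64, 64, kernel_size=3)                           # nn.Conv2d(64, 64, 3, padding=1)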
class LTX2AudioCausalConv2d(nn.Module):
@@ -147,14 +95,6 @@ class LTX2AudioPixelNorm(nn.Module):
        return x / rms


def build_normalization_layer(in_channels: int, *, num_groups: int = 32, normtype: str = "group") -> nn.Module:
    if normtype == "group":
        return nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
    if normtype == "pixel":
        return LTX2AudioPixelNorm(dim=1, eps=1e-6)
    raise ValueError(f"Invalid normalization type: {normtype}")

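Usage sketch for the factory shown above (illustrative channel count, assumes the module's classes are in scope):

group_norm = build_normalization_layer(256)                    # nn.GroupNorm(32, 256, eps=1e-6, affine=True)
pixel_norm = build_normalization_layer(256, normtype="pixel")  # LTX2AudioPixelNorm(dim=1, eps=1e-6)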
class LTX2AudioAttnBlock(nn.Module):
    def __init__(
        self,
@@ -164,7 +104,12 @@ class LTX2AudioAttnBlock(nn.Module):
        super().__init__()
        self.in_channels = in_channels

        self.norm = build_normalization_layer(in_channels, normtype=norm_type)
        if norm_type == "group":
            self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        elif norm_type == "pixel":
            self.norm = LTX2AudioPixelNorm(dim=1, eps=1e-6)
        else:
            raise ValueError(f"Invalid normalization type: {norm_type}")
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
@@ -211,23 +156,49 @@ class LTX2AudioResnetBlock(nn.Module):
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = build_normalization_layer(in_channels, normtype=norm_type)
        if norm_type == "group":
            self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        elif norm_type == "pixel":
            self.norm1 = LTX2AudioPixelNorm(dim=1, eps=1e-6)
        else:
            raise ValueError(f"Invalid normalization type: {norm_type}")
        self.non_linearity = nn.SiLU()
        self.conv1 = make_conv2d(in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
        if causality_axis is not None:
            self.conv1 = LTX2AudioCausalConv2d(
                in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
            )
        else:
            self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels > 0:
            self.temb_proj = nn.Linear(temb_channels, out_channels)
        self.norm2 = build_normalization_layer(out_channels, normtype=norm_type)
        if norm_type == "group":
            self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
        elif norm_type == "pixel":
            self.norm2 = LTX2AudioPixelNorm(dim=1, eps=1e-6)
        else:
            raise ValueError(f"Invalid normalization type: {norm_type}")
        self.dropout = nn.Dropout(dropout)
        self.conv2 = make_conv2d(out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
        if causality_axis is not None:
            self.conv2 = LTX2AudioCausalConv2d(
                out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
            )
        else:
            self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = make_conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
                )
                if causality_axis is not None:
                    self.conv_shortcut = LTX2AudioCausalConv2d(
                        in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
                    )
                else:
                    self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = make_conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
                )
                if causality_axis is not None:
                    self.nin_shortcut = LTX2AudioCausalConv2d(
                        in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
                    )
                else:
                    self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
        h = self.norm1(x)
@@ -254,7 +225,12 @@ class LTX2AudioUpsample(nn.Module):
        self.with_conv = with_conv
        self.causality_axis = causality_axis
        if self.with_conv:
            self.conv = make_conv2d(in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
            if causality_axis is not None:
                self.conv = LTX2AudioCausalConv2d(
                    in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis
                )
            else:
                self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
@@ -273,26 +249,6 @@ class LTX2AudioUpsample(nn.Module):
        return x


class LTX2AudioPerChannelStatistics(nn.Module):
    """
    Per-channel statistics for normalizing and denormalizing the latent representation. These statistics are computed
    over the entire dataset and stored in the model's checkpoint under the AudioVAE state_dict.
    """

    def __init__(self, latent_channels: int = 128) -> None:
        super().__init__()
        # Sayak notes: `empty` always causes problems in CI. Should we consider using `torch.ones`?
        self.register_buffer("std-of-means", torch.empty(latent_channels))
        self.register_buffer("mean-of-means", torch.empty(latent_channels))

    def denormalize(self, x: torch.Tensor) -> torch.Tensor:
        return (x * self.get_buffer("std-of-means").to(x)) + self.get_buffer("mean-of-means").to(x)

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        return (x - self.get_buffer("mean-of-means").to(x)) / self.get_buffer("std-of-means").to(x)
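These two methods are a per-channel affine map and its inverse; a standalone sketch with illustrative shapes (the statistics length only has to match the trailing dimension they are applied to):

import torch

stats_mean = torch.randn(128)
stats_std = torch.rand(128) + 0.5  # keep strictly positive

x = torch.randn(2, 24, 128)                           # e.g. [batch, time, channels * mel_bins]
denormalized = x * stats_std + stats_mean             # denormalize
recovered = (denormalized - stats_mean) / stats_std   # normalize

assert torch.allclose(recovered, x, atol=1e-5)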
class LTX2AudioAudioPatchifier:
    """
    Patchifier for spectrogram/audio latents.
@@ -316,11 +272,9 @@ class LTX2AudioAudioPatchifier:
        batch, channels, time, freq = audio_latents.shape
        return audio_latents.permute(0, 2, 1, 3).reshape(batch, time, channels * freq)

    def unpatchify(self, audio_latents: torch.Tensor, output_shape: AudioLatentShape) -> torch.Tensor:
    def unpatchify(self, audio_latents: torch.Tensor, channels: int, mel_bins: int) -> torch.Tensor:
        batch, time, _ = audio_latents.shape
        channels = output_shape.channels
        freq = output_shape.mel_bins
        return audio_latents.view(batch, time, channels, freq).permute(0, 2, 1, 3)
        return audio_latents.view(batch, time, channels, mel_bins).permute(0, 2, 1, 3)

    @property
    def patch_size(self) -> Tuple[int, int, int]:
@@ -356,9 +310,6 @@ class LTX2AudioDecoder(nn.Module):
    ) -> None:
        super().__init__()

        resolved_causality_axis = _resolve_causality_axis(causality_axis)

        self.per_channel_statistics = LTX2AudioPerChannelStatistics(latent_channels=base_channels)
        self.sample_rate = sample_rate
        self.mel_hop_length = mel_hop_length
        self.is_causal = is_causal
@@ -384,116 +335,43 @@ class LTX2AudioDecoder(nn.Module):
        self.latent_channels = latent_channels
        self.channel_multipliers = ch_mult
        self.attn_resolutions = attn_resolutions
        self.causality_axis = resolved_causality_axis
        self.causality_axis = causality_axis

        base_block_channels = base_channels * self.channel_multipliers[-1]
        base_resolution = resolution // (2 ** (self.num_resolutions - 1))
        self.z_shape = (1, latent_channels, base_resolution, base_resolution)

        self.conv_in = make_conv2d(
            latent_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
        )
        self.non_linearity = nn.SiLU()
        self.mid = self._build_mid_layers(base_block_channels, dropout, mid_block_add_attention)
        self.up, final_block_channels = self._build_up_path(
            initial_block_channels=base_block_channels,
            dropout=dropout,
            resamp_with_conv=True,
        )

        self.norm_out = build_normalization_layer(final_block_channels, normtype=self.norm_type)
        self.conv_out = make_conv2d(
            final_block_channels, output_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
        )

    def _adjust_output_shape(self, decoded_output: torch.Tensor, target_shape: AudioLatentShape) -> torch.Tensor:
        _, _, current_time, current_freq = decoded_output.shape
        target_channels = target_shape.channels
        target_time = target_shape.frames
        target_freq = target_shape.mel_bins

        decoded_output = decoded_output[
            :, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq)
        ]

        time_padding_needed = target_time - decoded_output.shape[2]
        freq_padding_needed = target_freq - decoded_output.shape[3]

        if time_padding_needed > 0 or freq_padding_needed > 0:
            padding = (
                0,
                max(freq_padding_needed, 0),
                0,
                max(time_padding_needed, 0),
            )
            decoded_output = F.pad(decoded_output, padding)

        decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq]

        return decoded_output

    def forward(
        self,
        sample: torch.Tensor,
    ) -> torch.Tensor:
        latent_shape = AudioLatentShape(
            batch=sample.shape[0],
            channels=sample.shape[1],
            frames=sample.shape[2],
            mel_bins=sample.shape[3],
        )

        sample_patched = self.patchifier.patchify(sample)
        sample_denormalized = self.per_channel_statistics.denormalize(sample_patched)
        sample = self.patchifier.unpatchify(sample_denormalized, latent_shape)

        target_frames = latent_shape.frames * LATENT_DOWNSAMPLE_FACTOR

        if self.causality_axis is not None:
            target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1)

        target_shape = AudioLatentShape(
            batch=latent_shape.batch,
            channels=self.out_ch,
            frames=target_frames,
            mel_bins=self.mel_bins if self.mel_bins is not None else latent_shape.mel_bins,
        )

        hidden_features = self.conv_in(sample)
        hidden_features = self._run_mid_layers(hidden_features)
        hidden_features = self._run_upsampling_path(hidden_features)
        decoded_output = self._finalize_output(hidden_features)

        decoded_output = self._adjust_output_shape(decoded_output, target_shape)

        return decoded_output

    def _build_mid_layers(self, channels: int, dropout: float, add_attention: bool) -> nn.Module:
        mid = nn.Module()
        mid.block_1 = LTX2AudioResnetBlock(
            in_channels=channels,
            out_channels=channels,
            self.conv_in = LTX2AudioCausalConv2d(
                latent_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
            )
        else:
            self.conv_in = nn.Conv2d(latent_channels, base_block_channels, kernel_size=3, stride=1, padding=1)
        self.non_linearity = nn.SiLU()
        self.mid = nn.Module()
        self.mid.block_1 = LTX2AudioResnetBlock(
            in_channels=base_block_channels,
            out_channels=base_block_channels,
            temb_channels=self.temb_ch,
            dropout=dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
        )
        mid.attn_1 = LTX2AudioAttnBlock(channels, norm_type=self.norm_type) if add_attention else nn.Identity()
        mid.block_2 = LTX2AudioResnetBlock(
            in_channels=channels,
            out_channels=channels,
        if mid_block_add_attention:
            self.mid.attn_1 = LTX2AudioAttnBlock(base_block_channels, norm_type=self.norm_type)
        else:
            self.mid.attn_1 = nn.Identity()
        self.mid.block_2 = LTX2AudioResnetBlock(
            in_channels=base_block_channels,
            out_channels=base_block_channels,
            temb_channels=self.temb_ch,
            dropout=dropout,
            norm_type=self.norm_type,
            causality_axis=self.causality_axis,
        )
        return mid

    def _build_up_path(
        self, initial_block_channels: int, dropout: float, resamp_with_conv: bool
    ) -> tuple[nn.ModuleList, int]:
        up_modules = nn.ModuleList()
        block_in = initial_block_channels
        self.up = nn.ModuleList()
        block_in = base_block_channels
        curr_res = self.resolution // (2 ** (self.num_resolutions - 1))

        for level in reversed(range(self.num_resolutions)):
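The `_adjust_output_shape` helper above (and the equivalent inline code in the following hunk) crops the decoded spectrogram to the target time/mel size and zero-pads when the decoder comes up short. A minimal standalone sketch of that crop-then-pad behavior, with made-up sizes:

import torch
import torch.nn.functional as F

decoded = torch.randn(1, 2, 97, 64)          # [batch, channels, time, mel_bins] from the decoder
target_channels, target_time, target_freq = 2, 96, 64

out = decoded[:, :target_channels, : min(decoded.shape[2], target_time), : min(decoded.shape[3], target_freq)]
time_pad = target_time - out.shape[2]
freq_pad = target_freq - out.shape[3]
if time_pad > 0 or freq_pad > 0:
    out = F.pad(out, (0, max(freq_pad, 0), 0, max(time_pad, 0)))  # pad last dim, then second-to-last
out = out[:, :target_channels, :target_time, :target_freq]

assert out.shape == (1, 2, 96, 64)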
@@ -519,39 +397,89 @@ class LTX2AudioDecoder(nn.Module):
                    stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type))

            if level != 0:
                stage.upsample = LTX2AudioUpsample(block_in, resamp_with_conv, causality_axis=self.causality_axis)
                stage.upsample = LTX2AudioUpsample(block_in, True, causality_axis=self.causality_axis)
                curr_res *= 2

            up_modules.insert(0, stage)
            self.up.insert(0, stage)

        return up_modules, block_in
        final_block_channels = block_in

    def _run_mid_layers(self, features: torch.Tensor) -> torch.Tensor:
        features = self.mid.block_1(features, temb=None)
        features = self.mid.attn_1(features)
        return self.mid.block_2(features, temb=None)
        if self.norm_type == "group":
            self.norm_out = nn.GroupNorm(
                num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True
            )
        elif self.norm_type == "pixel":
            self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6)
        else:
            raise ValueError(f"Invalid normalization type: {self.norm_type}")

        if self.causality_axis is not None:
            self.conv_out = LTX2AudioCausalConv2d(
                final_block_channels, output_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
            )
        else:
            self.conv_out = nn.Conv2d(final_block_channels, output_channels, kernel_size=3, stride=1, padding=1)

    def forward(
        self,
        sample: torch.Tensor,
    ) -> torch.Tensor:
        _, _, frames, mel_bins = sample.shape

        target_frames = frames * LATENT_DOWNSAMPLE_FACTOR

        if self.causality_axis is not None:
            target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1)

        target_channels = self.out_ch
        target_mel_bins = self.mel_bins if self.mel_bins is not None else mel_bins

        hidden_features = self.conv_in(sample)
        hidden_features = self.mid.block_1(hidden_features, temb=None)
        hidden_features = self.mid.attn_1(hidden_features)
        hidden_features = self.mid.block_2(hidden_features, temb=None)

    def _run_upsampling_path(self, features: torch.Tensor) -> torch.Tensor:
        for level in reversed(range(self.num_resolutions)):
            stage = self.up[level]
            for block_idx, block in enumerate(stage.block):
                features = block(features, temb=None)
                hidden_features = block(hidden_features, temb=None)
                if stage.attn:
                    features = stage.attn[block_idx](features)
                    hidden_features = stage.attn[block_idx](hidden_features)

            if level != 0 and hasattr(stage, "upsample"):
                features = stage.upsample(features)
                hidden_features = stage.upsample(hidden_features)

        return features

    def _finalize_output(self, features: torch.Tensor) -> torch.Tensor:
        if self.give_pre_end:
            return features
            return hidden_features

        hidden = self.norm_out(features)
        hidden = self.norm_out(hidden_features)
        hidden = self.non_linearity(hidden)
        decoded = self.conv_out(hidden)
        return torch.tanh(decoded) if self.tanh_out else decoded
        decoded_output = self.conv_out(hidden)
        decoded_output = torch.tanh(decoded_output) if self.tanh_out else decoded_output

        _, _, current_time, current_freq = decoded_output.shape
        target_time = target_frames
        target_freq = target_mel_bins

        decoded_output = decoded_output[
            :, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq)
        ]

        time_padding_needed = target_time - decoded_output.shape[2]
        freq_padding_needed = target_freq - decoded_output.shape[3]

        if time_padding_needed > 0 or freq_padding_needed > 0:
            padding = (
                0,
                max(freq_padding_needed, 0),
                0,
                max(time_padding_needed, 0),
            )
            decoded_output = F.pad(decoded_output, padding)

        decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq]

        return decoded_output

class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
@@ -583,7 +511,10 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
    ) -> None:
        super().__init__()

        resolved_causality_axis = _resolve_causality_axis(causality_axis)
        supported_causality_axes = {"none", "width", "height", "width-compatibility"}
        if causality_axis not in supported_causality_axes:
            raise ValueError(f"{causality_axis=} is not valid. Supported values: {supported_causality_axes}")

        attn_resolution_set = set(attn_resolutions) if attn_resolutions else attn_resolutions

        self.decoder = LTX2AudioDecoder(
@@ -596,7 +527,7 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
            resolution=resolution,
            latent_channels=latent_channels,
            norm_type=norm_type,
            causality_axis=resolved_causality_axis,
            causality_axis=causality_axis,
            dropout=dropout,
            mid_block_add_attention=mid_block_add_attention,
            sample_rate=sample_rate,
@@ -605,6 +536,13 @@ class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
            mel_bins=mel_bins,
        )

        # Per-channel statistics for normalizing and denormalizing the latent representation. These statistics are
        # computed over the entire dataset and stored in the model's checkpoint under the AudioVAE state_dict.
        latents_std = torch.ones((base_channels,))
        latents_mean = torch.zeros((base_channels,))
        self.register_buffer("latents_mean", latents_mean, persistent=True)
        self.register_buffer("latents_std", latents_std, persistent=True)

        # TODO: calculate programmatically instead of hardcoding
        self.temporal_compression_ratio = LATENT_DOWNSAMPLE_FACTOR  # 4
        # TODO: confirm whether the mel compression ratio below is correct
@@ -1051,6 +1051,7 @@ class LTX2VideoTransformer3DModel(
        encoder_hidden_states: torch.Tensor,
        audio_encoder_hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        audio_timestep: Optional[torch.LongTensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        audio_encoder_attention_mask: Optional[torch.Tensor] = None,
        num_frames: Optional[int] = None,
@@ -1073,8 +1074,7 @@ class LTX2VideoTransformer3DModel(
                Input patchified audio latents of shape (batch_size, num_audio_tokens, audio_in_channels).
            encoder_hidden_states (`torch.Tensor`):
                Input text embeddings of shape TODO.
            timesteps (`torch.Tensor`):
                Timestep information of shape (batch_size, num_train_timesteps).
            TODO for the rest.

        Returns:
            `AudioVisualModelOutput` or `tuple`:
@@ -1097,6 +1097,9 @@ class LTX2VideoTransformer3DModel(
                "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
            )

        # Determine timestep for audio.
        audio_timestep = audio_timestep if audio_timestep is not None else timestep

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
@@ -1143,7 +1146,7 @@ class LTX2VideoTransformer3DModel(
        embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))

        temb_audio, audio_embedded_timestep = self.audio_time_embed(
            timestep.flatten(),
            audio_timestep.flatten(),
            batch_size=batch_size,
            hidden_dtype=audio_hidden_states.dtype,
        )
@@ -1165,12 +1168,12 @@ class LTX2VideoTransformer3DModel(
        video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.view(batch_size, -1, video_cross_attn_a2v_gate.shape[-1])

        audio_cross_attn_scale_shift, _ = self.av_cross_attn_audio_scale_shift(
            timestep.flatten(),
            audio_timestep.flatten(),
            batch_size=batch_size,
            hidden_dtype=audio_hidden_states.dtype,
        )
        audio_cross_attn_v2a_gate, _ = self.av_cross_attn_audio_v2a_gate(
            timestep.flatten() * timestep_cross_attn_gate_scale_factor,
            audio_timestep.flatten() * timestep_cross_attn_gate_scale_factor,
            batch_size=batch_size,
            hidden_dtype=audio_hidden_states.dtype,
        )
@@ -288,7 +288,7 @@ else:
        "LTXConditionPipeline",
        "LTXLatentUpsamplePipeline",
    ]
    _import_structure["ltx2"] = ["LTX2Pipeline"]
    _import_structure["ltx2"] = ["LTX2Pipeline", "LTX2ImageToVideoPipeline"]
    _import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
    _import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
    _import_structure["lucy"] = ["LucyEditPipeline"]
@@ -720,7 +720,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
            LEditsPPPipelineStableDiffusionXL,
        )
        from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline
        from .ltx2 import LTX2Pipeline
        from .ltx2 import LTX2Pipeline, LTX2ImageToVideoPipeline
        from .lucy import LucyEditPipeline
        from .lumina import LuminaPipeline, LuminaText2ImgPipeline
        from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
@@ -23,6 +23,7 @@ except OptionalDependencyNotAvailable:
    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_ltx2"] = ["LTX2Pipeline"]
    _import_structure["pipeline_ltx2_image2video"] = ["LTX2ImageToVideoPipeline"]
    _import_structure["text_encoder"] = ["LTX2AudioVisualTextEncoder"]
    _import_structure["vocoder"] = ["LTX2Vocoder"]
@@ -35,6 +36,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_ltx2 import LTX2Pipeline
        from .pipeline_ltx2_image2video import LTX2ImageToVideoPipeline
        from .text_encoder import LTX2AudioVisualTextEncoder
        from .vocoder import LTX2Vocoder
@@ -496,16 +496,6 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
        latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
        return latents

    @staticmethod
    def _normalize_latents(
        latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
    ) -> torch.Tensor:
        # Normalize latents across the channel dimension [B, C, F, H, W]
        latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
        latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
        latents = (latents - latents_mean) * scaling_factor / latents_std
        return latents

    @staticmethod
    def _denormalize_latents(
        latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
@@ -516,6 +506,12 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
        latents = latents * latents_std / scaling_factor + latents_mean
        return latents

    @staticmethod
    def _denormalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor):
        latents_mean = latents_mean.to(latents.device, latents.dtype)
        latents_std = latents_std.to(latents.device, latents.dtype)
        return (latents * latents_std) + latents_mean

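Unlike `_denormalize_latents`, which reshapes the statistics to broadcast over the channel dimension of `[B, C, F, H, W]` video latents and honors a scaling factor, the new audio variant broadcasts over the last dimension of the packed `[B, T, C * F]` audio latents with no scaling factor. A small sketch of the two broadcast patterns (shapes illustrative):

import torch

mean, std = torch.zeros(128), torch.ones(128)

video_latents = torch.randn(1, 128, 4, 8, 8)   # [B, C, F, H, W]
video = video_latents * std.view(1, -1, 1, 1, 1) / 1.0 + mean.view(1, -1, 1, 1, 1)

audio_latents = torch.randn(1, 24, 128)        # [B, T, C * F] packed audio tokens
audio = audio_latents * std + mean             # broadcasts over the trailing dimension

assert video.shape == video_latents.shape and audio.shape == audio_latents.shape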
    @staticmethod
    def _pack_audio_latents(
        latents: torch.Tensor, patch_size: Optional[int] = None, patch_size_t: Optional[int] = None
@@ -1038,10 +1034,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
            video = self.vae.decode(latents, timestep, return_dict=False)[0]
            video = self.video_processor.postprocess_video(video, output_type=output_type)

            audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins)
            audio_latents = audio_latents.to(self.audio_vae.dtype)
            # NOTE: currently, unlike the video VAE, we denormalize the audio latents inside the audio VAE decoder's
            # decode method
            audio_latents = self._denormalize_audio_latents(
                audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std
            )
            audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins)
            generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0]
            audio = self.vocoder(generated_mel_spectrograms)
src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py (new file, 1138 lines): file diff suppressed because it is too large.