mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Merge pull request #3 from huggingface/ltx-2-vocoder
LTX 2.0 Vocoder Implementation
This commit is contained in:
@@ -10,6 +10,7 @@ from huggingface_hub import hf_hub_download
|
||||
|
||||
from diffusers import AutoencoderKLLTX2Video, LTX2VideoTransformer3DModel
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
|
||||
|
||||
|
||||
CTX = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
@@ -61,6 +62,13 @@ LTX_2_0_VIDEO_VAE_RENAME_DICT = {
|
||||
"per_channel_statistics.std-of-means": "latents_std",
|
||||
}
|
||||
|
||||
LTX_2_0_VOCODER_RENAME_DICT = {
|
||||
"ups": "upsamplers",
|
||||
"resblocks": "resnets",
|
||||
"conv_pre": "conv_in",
|
||||
"conv_post": "conv_out",
|
||||
}
|
||||
|
||||
|
||||
def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
@@ -99,6 +107,8 @@ LTX_2_0_VAE_SPECIAL_KEYS_REMAP = {
|
||||
"per_channel_statistics.mean-of-stds": remove_keys_inplace,
|
||||
}
|
||||
|
||||
LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {}
|
||||
|
||||
|
||||
def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
if version == "test":
|
||||
@@ -315,6 +325,53 @@ def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str) ->
|
||||
return vae
|
||||
|
||||
|
||||
def get_ltx2_vocoder_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
if version == "2.0":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/new-ltx-model",
|
||||
"diffusers_config": {
|
||||
"in_channels": 128,
|
||||
"hidden_channels": 1024,
|
||||
"out_channels": 2,
|
||||
"upsample_kernel_sizes": [16, 15, 8, 4, 4],
|
||||
"upsample_factors": [6, 5, 2, 2, 2],
|
||||
"resnet_kernel_sizes": [3, 7, 11],
|
||||
"resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
"leaky_relu_negative_slope": 0.1,
|
||||
"output_sampling_rate": 24000,
|
||||
}
|
||||
}
|
||||
rename_dict = LTX_2_0_VOCODER_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
def convert_ltx2_vocoder(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
|
||||
config, rename_dict, special_keys_remap = get_ltx2_vocoder_config(version)
|
||||
diffusers_config = config["diffusers_config"]
|
||||
|
||||
with init_empty_weights():
|
||||
vocoder = LTX2Vocoder.from_config(diffusers_config)
|
||||
|
||||
# Handle official code --> diffusers key remapping via the remap dict
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in rename_dict.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(original_state_dict, key, new_key)
|
||||
|
||||
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
|
||||
# special_keys_remap
|
||||
for key in list(original_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in special_keys_remap.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
vocoder.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
return vocoder
|
||||
|
||||
|
||||
def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]:
|
||||
if args.original_state_dict_repo_id is not None:
|
||||
ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename)
|
||||
@@ -468,7 +525,13 @@ def main(args):
|
||||
transformer.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "transformer"))
|
||||
|
||||
if args.vocoder or args.full_pipeline:
|
||||
pass
|
||||
if args.vocoder_filename is not None:
|
||||
original_vocoder_ckpt = load_hub_or_local_checkpoint(filename=args.vocoder_filename)
|
||||
elif combined_ckpt is not None:
|
||||
original_vocoder_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vocoder_prefix)
|
||||
vocoder = convert_ltx2_vocoder(original_vocoder_ckpt, version=args.version)
|
||||
if not args.full_pipeline:
|
||||
vocoder.to(vocoder_dtype).save_pretrained(os.path.join(args.output_path, "vocoder"))
|
||||
|
||||
if args.full_pipeline:
|
||||
pass
|
||||
|
||||
173
src/diffusers/pipelines/ltx2/vocoder.py
Normal file
173
src/diffusers/pipelines/ltx2/vocoder.py
Normal file
@@ -0,0 +1,173 @@
|
||||
import math
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
kernel_size: int = 3,
|
||||
stride: int = 1,
|
||||
dilations: Tuple[int, ...] = (1, 3, 5),
|
||||
leaky_relu_negative_slope: float = 0.1,
|
||||
padding_mode: str = "same",
|
||||
):
|
||||
super().__init__()
|
||||
self.dilations = dilations
|
||||
self.negative_slope = leaky_relu_negative_slope
|
||||
|
||||
self.convs1 = nn.ModuleList(
|
||||
[
|
||||
nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
dilation=dilation,
|
||||
padding=padding_mode
|
||||
)
|
||||
for dilation in dilations
|
||||
]
|
||||
)
|
||||
|
||||
self.convs2 = nn.ModuleList(
|
||||
[
|
||||
nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
dilation=1,
|
||||
padding=padding_mode
|
||||
)
|
||||
for _ in range(len(dilations))
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
for conv1, conv2 in zip(self.convs1, self.convs2):
|
||||
xt = F.leaky_relu(x, negative_slope=self.negative_slope)
|
||||
xt = conv1(xt)
|
||||
xt = F.leaky_relu(xt, negative_slope=self.negative_slope)
|
||||
xt = conv2(xt)
|
||||
x = x + xt
|
||||
return x
|
||||
|
||||
|
||||
class LTX2Vocoder(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
LTX 2.0 vocoder for converting generated mel spectrograms back to audio waveforms.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 128,
|
||||
hidden_channels: int = 1024,
|
||||
out_channels: int = 2,
|
||||
upsample_kernel_sizes: List[int] = [16, 15, 8, 4, 4],
|
||||
upsample_factors: List[int] = [6, 5, 2, 2, 2],
|
||||
resnet_kernel_sizes: List[int] = [3, 7, 11],
|
||||
resnet_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
leaky_relu_negative_slope: float = 0.1,
|
||||
output_sampling_rate: int = 24000,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_upsample_layers = len(upsample_kernel_sizes)
|
||||
self.resnets_per_upsample = len(resnet_kernel_sizes)
|
||||
self.out_channels = out_channels
|
||||
self.total_upsample_factor = math.prod(upsample_factors)
|
||||
self.negative_slope = leaky_relu_negative_slope
|
||||
|
||||
if self.num_upsample_layers != len(upsample_factors):
|
||||
raise ValueError(
|
||||
f"`upsample_kernel_sizes` and `upsample_factors` should be lists of the same length but are length"
|
||||
f" {self.num_upsample_layers} and {len(upsample_factors)}, respectively."
|
||||
)
|
||||
|
||||
if self.resnets_per_upsample != len(resnet_dilations):
|
||||
raise ValueError(
|
||||
f"`resnet_kernel_sizes` and `resnet_dilations` should be lists of the same length but are length"
|
||||
f" {len(self.resnets_per_upsample)} and {len(resnet_dilations)}, respectively."
|
||||
)
|
||||
|
||||
self.conv_in = nn.Conv1d(in_channels, hidden_channels, kernel_size=7, stride=1, padding=3)
|
||||
|
||||
self.upsamplers = nn.ModuleList()
|
||||
self.resnets = nn.ModuleList()
|
||||
input_channels = hidden_channels
|
||||
for i, (stride, kernel_size) in enumerate(zip(upsample_factors, upsample_kernel_sizes)):
|
||||
output_channels = input_channels // 2
|
||||
self.upsamplers.append(
|
||||
nn.ConvTranspose1d(
|
||||
input_channels, # hidden_channels // (2 ** i)
|
||||
output_channels, # hidden_channels // (2 ** (i + 1))
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
padding=(kernel_size - stride) // 2,
|
||||
)
|
||||
)
|
||||
|
||||
for kernel_size, dilations in zip(resnet_kernel_sizes, resnet_dilations):
|
||||
self.resnets.append(
|
||||
ResBlock(
|
||||
output_channels,
|
||||
kernel_size,
|
||||
dilations=dilations,
|
||||
leaky_relu_negative_slope=leaky_relu_negative_slope,
|
||||
)
|
||||
)
|
||||
input_channels = output_channels
|
||||
|
||||
self.conv_out = nn.Conv1d(output_channels, out_channels, 7, stride=1, padding=3)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, time_last: bool = False) -> torch.Tensor:
|
||||
r"""
|
||||
Forward pass of the vocoder.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.Tensor`):
|
||||
Input Mel spectrogram tensor of shape `(batch_size, num_channels, time, num_mel_bins)` if `time_last`
|
||||
is `False` (the default) or shape `(batch_size, num_channels, num_mel_bins, time)` if `time_last` is
|
||||
`True`.
|
||||
time_last (`bool`, *optional*, defaults to `False`):
|
||||
Whether the last dimension of the input is the time/frame dimension or the Mel bins dimension.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
Audio waveform tensor of shape (batch_size, out_channels, audio_length)
|
||||
"""
|
||||
|
||||
# Ensure that the time/frame dimension is last
|
||||
if not time_last:
|
||||
hidden_states = hidden_states.transpose(2, 3)
|
||||
# Combine channels and frequency (mel bins) dimensions
|
||||
hidden_states = hidden_states.flatten(1, 2)
|
||||
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
|
||||
for i in range(self.num_upsample_layers):
|
||||
hidden_states = F.leaky_relu(hidden_states, negative_slope=self.negative_slope)
|
||||
hidden_states = self.upsamplers[i](hidden_states)
|
||||
|
||||
# Run all resnets in parallel on hidden_states
|
||||
start = i * self.resnets_per_upsample
|
||||
end = (i + 1) * self.resnets_per_upsample
|
||||
resnet_outputs = torch.stack([self.resnets[j](hidden_states) for j in range(start, end)], dim=0)
|
||||
|
||||
hidden_states = torch.mean(resnet_outputs, dim=0)
|
||||
|
||||
# NOTE: unlike the first leaky ReLU, this leaky ReLU is set to use the default F.leaky_relu negative slope of
|
||||
# 0.01 (whereas the others usually use a slope of 0.1). Not sure if this is intended
|
||||
hidden_states = F.leaky_relu(hidden_states, negative_slope=0.01)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
hidden_states = torch.tanh(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
Reference in New Issue
Block a user