mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Convert MusicLDM (#4579)
* from audioldm * fix vae * move to new pipeline * copied from audioldm * remove redundant control flow * iterate * fix docstring * finish pipeline * tests: from audioldm2 * iterate * finish fast tests * finish slow integration tests * add docs * remove dtype test * update toctree * "copied from" in conversion (where possible) * Update docs/source/en/api/pipelines/musicldm.md Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * fix docstring * make nightly * style * fix dtype test --------- Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -224,6 +224,8 @@
|
||||
title: Latent Diffusion
|
||||
- local: api/pipelines/panorama
|
||||
title: MultiDiffusion
|
||||
- local: api/pipelines/musicldm
|
||||
title: MusicLDM
|
||||
- local: api/pipelines/paint_by_example
|
||||
title: PaintByExample
|
||||
- local: api/pipelines/paradigms
|
||||
|
||||
57
docs/source/en/api/pipelines/musicldm.md
Normal file
57
docs/source/en/api/pipelines/musicldm.md
Normal file
@@ -0,0 +1,57 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# MusicLDM
|
||||
|
||||
MusicLDM was proposed in [MusicLDM: Enhancing Novelty in Text-to-Music Generation Using Beat-Synchronous Mixup Strategies](https://huggingface.co/papers/2308.01546) by Ke Chen, Yusong Wu, Haohe Liu, Marianna Nezhurina, Taylor Berg-Kirkpatrick, Shlomo Dubnov.
|
||||
MusicLDM takes a text prompt as input and predicts the corresponding music sample.
|
||||
|
||||
Inspired by [Stable Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview) and [AudioLDM](https://huggingface.co/docs/diffusers/api/pipelines/audioldm/overview),
|
||||
MusicLDM is a text-to-music _latent diffusion model (LDM)_ that learns continuous audio representations from [CLAP](https://huggingface.co/docs/transformers/main/model_doc/clap)
|
||||
latents.
|
||||
|
||||
MusicLDM is trained on a corpus of 466 hours of music data. Beat-synchronous data augmentation strategies are applied to
|
||||
the music samples, both in the time domain and in the latent space. Using beat-synchronous data augmentation strategies
|
||||
encourages the model to interpolate between the training samples, but stay within the domain of the training data. The
|
||||
result is generated music that is more diverse while staying faithful to the corresponding style.
|
||||
|
||||
The abstract of the paper is the following:
|
||||
|
||||
*In this paper, we present MusicLDM, a state-of-the-art text-to-music model that adapts Stable Diffusion and AudioLDM architectures to the music domain. We achieve this by retraining the contrastive language-audio pretraining model (CLAP) and the Hifi-GAN vocoder, as components of MusicLDM, on a collection of music data samples. Then, we leverage a beat tracking model and propose two different mixup strategies for data augmentation: beat-synchronous audio mixup and beat-synchronous latent mixup, to encourage the model to generate music more diverse while still staying faithful to the corresponding style.*
|
||||
|
||||
This pipeline was contributed by [sanchit-gandhi](https://huggingface.co/sanchit-gandhi).
|
||||
|
||||
## Tips
|
||||
|
||||
When constructing a prompt, keep in mind:
|
||||
|
||||
* Descriptive prompt inputs work best; use adjectives to describe the sound (for example, "high quality" or "clear") and make the prompt context specific where possible (e.g. "melodic techno with a fast beat and synths" works better than "techno").
|
||||
* Using a *negative prompt* can significantly improve the quality of the generated audio. Try using a negative prompt of "low quality, average quality".
|
||||
|
||||
During inference:
|
||||
|
||||
* The _quality_ of the generated audio sample can be controlled by the `num_inference_steps` argument; higher steps give higher quality audio at the expense of slower inference.
|
||||
* Multiple waveforms can be generated in one go: set `num_waveforms_per_prompt` to a value greater than 1 to enable. Automatic scoring will be performed between the generated waveforms and prompt text, and the audios ranked from best to worst accordingly.
|
||||
* The _length_ of the generated audio sample can be controlled by varying the `audio_length_in_s` argument.
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](/using-diffusers/schedulers) to learn how to explore the tradeoff between
|
||||
scheduler speed and quality, and see the [reuse components across pipelines](/using-diffusers/loading#reuse-components-across-pipelines)
|
||||
section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
## MusicLDMPipeline
|
||||
[[autodoc]] MusicLDMPipeline
|
||||
- all
|
||||
- __call__
|
||||
1064
scripts/convert_original_musicldm_to_diffusers.py
Normal file
1064
scripts/convert_original_musicldm_to_diffusers.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -163,6 +163,7 @@ else:
|
||||
KandinskyV22PriorEmb2EmbPipeline,
|
||||
KandinskyV22PriorPipeline,
|
||||
LDMTextToImagePipeline,
|
||||
MusicLDMPipeline,
|
||||
PaintByExamplePipeline,
|
||||
SemanticStableDiffusionPipeline,
|
||||
ShapEImg2ImgPipeline,
|
||||
|
||||
@@ -83,6 +83,7 @@ else:
|
||||
KandinskyV22PriorPipeline,
|
||||
)
|
||||
from .latent_diffusion import LDMTextToImagePipeline
|
||||
from .musicldm import MusicLDMPipeline
|
||||
from .paint_by_example import PaintByExamplePipeline
|
||||
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
|
||||
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
|
||||
|
||||
17
src/diffusers/pipelines/musicldm/__init__.py
Normal file
17
src/diffusers/pipelines/musicldm/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from ...utils import (
|
||||
OptionalDependencyNotAvailable,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
is_transformers_version,
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import (
|
||||
MusicLDMPipeline,
|
||||
)
|
||||
else:
|
||||
from .pipeline_musicldm import MusicLDMPipeline
|
||||
607
src/diffusers/pipelines/musicldm/pipeline_musicldm.py
Normal file
607
src/diffusers/pipelines/musicldm/pipeline_musicldm.py
Normal file
@@ -0,0 +1,607 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import (
|
||||
ClapFeatureExtractor,
|
||||
ClapModel,
|
||||
ClapTextModelWithProjection,
|
||||
RobertaTokenizer,
|
||||
RobertaTokenizerFast,
|
||||
SpeechT5HifiGan,
|
||||
)
|
||||
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import is_librosa_available, logging, randn_tensor, replace_example_docstring
|
||||
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
|
||||
|
||||
|
||||
if is_librosa_available():
|
||||
import librosa
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> from diffusers import MusicLDMPipeline
|
||||
>>> import torch
|
||||
>>> import scipy
|
||||
|
||||
>>> repo_id = "cvssp/audioldm-s-full-v2"
|
||||
>>> pipe = MusicLDMPipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
|
||||
>>> pipe = pipe.to("cuda")
|
||||
|
||||
>>> prompt = "Techno music with a strong, upbeat tempo and high melodic riffs"
|
||||
>>> audio = pipe(prompt, num_inference_steps=10, audio_length_in_s=5.0).audios[0]
|
||||
|
||||
>>> # save the audio sample as a .wav file
|
||||
>>> scipy.io.wavfile.write("techno.wav", rate=16000, data=audio)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class MusicLDMPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-audio generation using MusicLDM.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`~transformers.ClapModel`]):
|
||||
Frozen text-audio embedding model (`ClapTextModel`), specifically the
|
||||
[laion/clap-htsat-unfused](https://huggingface.co/laion/clap-htsat-unfused) variant.
|
||||
tokenizer ([`PreTrainedTokenizer`]):
|
||||
A [`~transformers.RobertaTokenizer`] to tokenize text.
|
||||
feature_extractor ([`~transformers.ClapFeatureExtractor`]):
|
||||
Feature extractor to compute mel-spectrograms from audio waveforms.
|
||||
unet ([`UNet2DConditionModel`]):
|
||||
A `UNet2DConditionModel` to denoise the encoded audio latents.
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded audio latents. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
vocoder ([`~transformers.SpeechT5HifiGan`]):
|
||||
Vocoder of class `SpeechT5HifiGan`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: Union[ClapTextModelWithProjection, ClapModel],
|
||||
tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast],
|
||||
feature_extractor: Optional[ClapFeatureExtractor],
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
vocoder: SpeechT5HifiGan,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
feature_extractor=feature_extractor,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vocoder=vocoder,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
|
||||
def enable_vae_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.vae.enable_slicing()
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
|
||||
def disable_vae_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_slicing()
|
||||
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_waveforms_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device (`torch.device`):
|
||||
torch device
|
||||
num_waveforms_per_prompt (`int`):
|
||||
number of waveforms that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the audio generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
"""
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
attention_mask = text_inputs.attention_mask
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = self.tokenizer.batch_decode(
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLAP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder.get_text_features(
|
||||
text_input_ids.to(device),
|
||||
attention_mask=attention_mask.to(device),
|
||||
)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.text_model.dtype, device=device)
|
||||
|
||||
(
|
||||
bs_embed,
|
||||
seq_len,
|
||||
) = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_waveforms_per_prompt)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_waveforms_per_prompt, seq_len)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
max_length = prompt_embeds.shape[1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
uncond_input_ids = uncond_input.input_ids.to(device)
|
||||
attention_mask = uncond_input.attention_mask.to(device)
|
||||
|
||||
negative_prompt_embeds = self.text_encoder.get_text_features(
|
||||
uncond_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.text_model.dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_waveforms_per_prompt)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_waveforms_per_prompt, seq_len)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.audioldm.pipeline_audioldm.AudioLDMPipeline.mel_spectrogram_to_waveform
|
||||
def mel_spectrogram_to_waveform(self, mel_spectrogram):
|
||||
if mel_spectrogram.dim() == 4:
|
||||
mel_spectrogram = mel_spectrogram.squeeze(1)
|
||||
|
||||
waveform = self.vocoder(mel_spectrogram)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||
waveform = waveform.cpu().float()
|
||||
return waveform
|
||||
|
||||
# Copied from diffusers.pipelines.audioldm2.pipeline_audioldm2.AudioLDM2Pipeline.score_waveforms
|
||||
def score_waveforms(self, text, audio, num_waveforms_per_prompt, device, dtype):
|
||||
if not is_librosa_available():
|
||||
logger.info(
|
||||
"Automatic scoring of the generated audio waveforms against the input prompt text requires the "
|
||||
"`librosa` package to resample the generated waveforms. Returning the audios in the order they were "
|
||||
"generated. To enable automatic scoring, install `librosa` with: `pip install librosa`."
|
||||
)
|
||||
return audio
|
||||
inputs = self.tokenizer(text, return_tensors="pt", padding=True)
|
||||
resampled_audio = librosa.resample(
|
||||
audio.numpy(), orig_sr=self.vocoder.config.sampling_rate, target_sr=self.feature_extractor.sampling_rate
|
||||
)
|
||||
inputs["input_features"] = self.feature_extractor(
|
||||
list(resampled_audio), return_tensors="pt", sampling_rate=self.feature_extractor.sampling_rate
|
||||
).input_features.type(dtype)
|
||||
inputs = inputs.to(device)
|
||||
|
||||
# compute the audio-text similarity score using the CLAP model
|
||||
logits_per_text = self.text_encoder(**inputs).logits_per_text
|
||||
# sort by the highest matching generations per prompt
|
||||
indices = torch.argsort(logits_per_text, dim=1, descending=True)[:, :num_waveforms_per_prompt]
|
||||
audio = torch.index_select(audio, 0, indices.reshape(-1).cpu())
|
||||
return audio
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
# Copied from diffusers.pipelines.audioldm.pipeline_audioldm.AudioLDMPipeline.check_inputs
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
audio_length_in_s,
|
||||
vocoder_upsample_factor,
|
||||
callback_steps,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
):
|
||||
min_audio_length_in_s = vocoder_upsample_factor * self.vae_scale_factor
|
||||
if audio_length_in_s < min_audio_length_in_s:
|
||||
raise ValueError(
|
||||
f"`audio_length_in_s` has to be a positive value greater than or equal to {min_audio_length_in_s}, but "
|
||||
f"is {audio_length_in_s}."
|
||||
)
|
||||
|
||||
if self.vocoder.config.model_in_dim % self.vae_scale_factor != 0:
|
||||
raise ValueError(
|
||||
f"The number of frequency bins in the vocoder's log-mel spectrogram has to be divisible by the "
|
||||
f"VAE scale factor, but got {self.vocoder.config.model_in_dim} bins and a scale factor of "
|
||||
f"{self.vae_scale_factor}."
|
||||
)
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.audioldm.pipeline_audioldm.AudioLDMPipeline.prepare_latents
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, dtype, device, generator, latents=None):
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height // self.vae_scale_factor,
|
||||
self.vocoder.config.model_in_dim // self.vae_scale_factor,
|
||||
)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
audio_length_in_s: Optional[float] = None,
|
||||
num_inference_steps: int = 200,
|
||||
guidance_scale: float = 2.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_waveforms_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
output_type: Optional[str] = "np",
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`.
|
||||
audio_length_in_s (`int`, *optional*, defaults to 10.24):
|
||||
The length of the generated audio sample in seconds.
|
||||
num_inference_steps (`int`, *optional*, defaults to 200):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality audio at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 2.0):
|
||||
A higher guidance scale value encourages the model to generate audio that is closely linked to the text
|
||||
`prompt` at the expense of lower sound quality. Guidance scale is enabled when `guidance_scale > 1`.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide what to not include in audio generation. If not defined, you need to
|
||||
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
||||
num_waveforms_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of waveforms to generate per prompt. If `num_waveforms_per_prompt > 1`, the text encoding
|
||||
model is a joint text-audio model ([`~transformers.ClapModel`]), and the tokenizer is a
|
||||
`[~transformers.ClapProcessor]`, then automatic scoring will be performed between the generated outputs
|
||||
and the input text. This scoring ranks the generated waveforms based on their cosine similarity to text
|
||||
input in the joint text-audio embedding space.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
||||
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor is generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
||||
provided, text embeddings are generated from the `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
||||
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.AudioPipelineOutput`] instead of a plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that calls every `callback_steps` steps during inference. The function is called with the
|
||||
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function is called. If not specified, the callback is called at
|
||||
every step.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
||||
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
output_type (`str`, *optional*, defaults to `"np"`):
|
||||
The output format of the generated audio. Choose between `"np"` to return a NumPy `np.ndarray` or
|
||||
`"pt"` to return a PyTorch `torch.Tensor` object. Set to `"latent"` to return the latent diffusion
|
||||
model (LDM) output.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.AudioPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.AudioPipelineOutput`] is returned, otherwise a `tuple` is
|
||||
returned where the first element is a list with the generated audio.
|
||||
"""
|
||||
# 0. Convert audio input length from seconds to spectrogram height
|
||||
vocoder_upsample_factor = np.prod(self.vocoder.config.upsample_rates) / self.vocoder.config.sampling_rate
|
||||
|
||||
if audio_length_in_s is None:
|
||||
audio_length_in_s = self.unet.config.sample_size * self.vae_scale_factor * vocoder_upsample_factor
|
||||
|
||||
height = int(audio_length_in_s / vocoder_upsample_factor)
|
||||
|
||||
original_waveform_length = int(audio_length_in_s * self.vocoder.config.sampling_rate)
|
||||
if height % self.vae_scale_factor != 0:
|
||||
height = int(np.ceil(height / self.vae_scale_factor)) * self.vae_scale_factor
|
||||
logger.info(
|
||||
f"Audio length in seconds {audio_length_in_s} is increased to {height * vocoder_upsample_factor} "
|
||||
f"so that it can be handled by the model. It will be cut to {audio_length_in_s} after the "
|
||||
f"denoising process."
|
||||
)
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
audio_length_in_s,
|
||||
vocoder_upsample_factor,
|
||||
callback_steps,
|
||||
negative_prompt,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
)
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_waveforms_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.unet.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_waveforms_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Prepare extra step kwargs
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=None,
|
||||
class_labels=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# 8. Post-processing
|
||||
if not output_type == "latent":
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
mel_spectrogram = self.vae.decode(latents).sample
|
||||
else:
|
||||
return AudioPipelineOutput(audios=latents)
|
||||
|
||||
audio = self.mel_spectrogram_to_waveform(mel_spectrogram)
|
||||
|
||||
audio = audio[:, :original_waveform_length]
|
||||
|
||||
# 9. Automatic scoring
|
||||
if num_waveforms_per_prompt > 1 and prompt is not None:
|
||||
audio = self.score_waveforms(
|
||||
text=prompt,
|
||||
audio=audio,
|
||||
num_waveforms_per_prompt=num_waveforms_per_prompt,
|
||||
device=device,
|
||||
dtype=prompt_embeds.dtype,
|
||||
)
|
||||
|
||||
if output_type == "np":
|
||||
audio = audio.numpy()
|
||||
|
||||
if not return_dict:
|
||||
return (audio,)
|
||||
|
||||
return AudioPipelineOutput(audios=audio)
|
||||
@@ -482,6 +482,21 @@ class LDMTextToImagePipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class MusicLDMPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class PaintByExamplePipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
0
tests/pipelines/musicldm/__init__.py
Normal file
0
tests/pipelines/musicldm/__init__.py
Normal file
465
tests/pipelines/musicldm/test_musicldm.py
Normal file
465
tests/pipelines/musicldm/test_musicldm.py
Normal file
@@ -0,0 +1,465 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import (
|
||||
ClapAudioConfig,
|
||||
ClapConfig,
|
||||
ClapFeatureExtractor,
|
||||
ClapModel,
|
||||
ClapTextConfig,
|
||||
RobertaTokenizer,
|
||||
SpeechT5HifiGan,
|
||||
SpeechT5HifiGanConfig,
|
||||
)
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDIMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
MusicLDMPipeline,
|
||||
PNDMScheduler,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.utils import is_xformers_available, torch_device
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu
|
||||
|
||||
from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class MusicLDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = MusicLDMPipeline
|
||||
params = TEXT_TO_AUDIO_PARAMS
|
||||
batch_params = TEXT_TO_AUDIO_BATCH_PARAMS
|
||||
required_optional_params = frozenset(
|
||||
[
|
||||
"num_inference_steps",
|
||||
"num_waveforms_per_prompt",
|
||||
"generator",
|
||||
"latents",
|
||||
"output_type",
|
||||
"return_dict",
|
||||
"callback",
|
||||
"callback_steps",
|
||||
]
|
||||
)
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
unet = UNet2DConditionModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
cross_attention_dim=(32, 64),
|
||||
class_embed_type="simple_projection",
|
||||
projection_class_embeddings_input_dim=32,
|
||||
class_embeddings_concat=True,
|
||||
)
|
||||
scheduler = DDIMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
block_out_channels=[32, 64],
|
||||
in_channels=1,
|
||||
out_channels=1,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
text_branch_config = ClapTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=16,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=2,
|
||||
num_hidden_layers=2,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
)
|
||||
audio_branch_config = ClapAudioConfig(
|
||||
spec_size=64,
|
||||
window_size=4,
|
||||
num_mel_bins=64,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
depths=[2, 2],
|
||||
num_attention_heads=[2, 2],
|
||||
num_hidden_layers=2,
|
||||
hidden_size=192,
|
||||
patch_size=2,
|
||||
patch_stride=2,
|
||||
patch_embed_input_channels=4,
|
||||
)
|
||||
text_encoder_config = ClapConfig.from_text_audio_configs(
|
||||
text_config=text_branch_config, audio_config=audio_branch_config, projection_dim=32
|
||||
)
|
||||
text_encoder = ClapModel(text_encoder_config)
|
||||
tokenizer = RobertaTokenizer.from_pretrained("hf-internal-testing/tiny-random-roberta", model_max_length=77)
|
||||
feature_extractor = ClapFeatureExtractor.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-ClapModel", hop_length=7900
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
vocoder_config = SpeechT5HifiGanConfig(
|
||||
model_in_dim=8,
|
||||
sampling_rate=16000,
|
||||
upsample_initial_channel=16,
|
||||
upsample_rates=[2, 2],
|
||||
upsample_kernel_sizes=[4, 4],
|
||||
resblock_kernel_sizes=[3, 7],
|
||||
resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5]],
|
||||
normalize_before=False,
|
||||
)
|
||||
|
||||
vocoder = SpeechT5HifiGan(vocoder_config)
|
||||
|
||||
components = {
|
||||
"unet": unet,
|
||||
"scheduler": scheduler,
|
||||
"vae": vae,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"feature_extractor": feature_extractor,
|
||||
"vocoder": vocoder,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
inputs = {
|
||||
"prompt": "A hammer hitting a wooden surface",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_musicldm_ddim(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
|
||||
components = self.get_dummy_components()
|
||||
musicldm_pipe = MusicLDMPipeline(**components)
|
||||
musicldm_pipe = musicldm_pipe.to(torch_device)
|
||||
musicldm_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
output = musicldm_pipe(**inputs)
|
||||
audio = output.audios[0]
|
||||
|
||||
assert audio.ndim == 1
|
||||
assert len(audio) == 256
|
||||
|
||||
audio_slice = audio[:10]
|
||||
expected_slice = np.array(
|
||||
[-0.0027, -0.0036, -0.0037, -0.0020, -0.0035, -0.0019, -0.0037, -0.0020, -0.0038, -0.0019]
|
||||
)
|
||||
|
||||
assert np.abs(audio_slice - expected_slice).max() < 1e-4
|
||||
|
||||
def test_musicldm_prompt_embeds(self):
|
||||
components = self.get_dummy_components()
|
||||
musicldm_pipe = MusicLDMPipeline(**components)
|
||||
musicldm_pipe = musicldm_pipe.to(torch_device)
|
||||
musicldm_pipe = musicldm_pipe.to(torch_device)
|
||||
musicldm_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs["prompt"] = 3 * [inputs["prompt"]]
|
||||
|
||||
# forward
|
||||
output = musicldm_pipe(**inputs)
|
||||
audio_1 = output.audios[0]
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
prompt = 3 * [inputs.pop("prompt")]
|
||||
|
||||
text_inputs = musicldm_pipe.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=musicldm_pipe.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_inputs = text_inputs["input_ids"].to(torch_device)
|
||||
|
||||
prompt_embeds = musicldm_pipe.text_encoder.get_text_features(text_inputs)
|
||||
|
||||
inputs["prompt_embeds"] = prompt_embeds
|
||||
|
||||
# forward
|
||||
output = musicldm_pipe(**inputs)
|
||||
audio_2 = output.audios[0]
|
||||
|
||||
assert np.abs(audio_1 - audio_2).max() < 1e-2
|
||||
|
||||
def test_musicldm_negative_prompt_embeds(self):
|
||||
components = self.get_dummy_components()
|
||||
musicldm_pipe = MusicLDMPipeline(**components)
|
||||
musicldm_pipe = musicldm_pipe.to(torch_device)
|
||||
musicldm_pipe = musicldm_pipe.to(torch_device)
|
||||
musicldm_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
negative_prompt = 3 * ["this is a negative prompt"]
|
||||
inputs["negative_prompt"] = negative_prompt
|
||||
inputs["prompt"] = 3 * [inputs["prompt"]]
|
||||
|
||||
# forward
|
||||
output = musicldm_pipe(**inputs)
|
||||
audio_1 = output.audios[0]
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
prompt = 3 * [inputs.pop("prompt")]
|
||||
|
||||
embeds = []
|
||||
for p in [prompt, negative_prompt]:
|
||||
text_inputs = musicldm_pipe.tokenizer(
|
||||
p,
|
||||
padding="max_length",
|
||||
max_length=musicldm_pipe.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_inputs = text_inputs["input_ids"].to(torch_device)
|
||||
|
||||
text_embeds = musicldm_pipe.text_encoder.get_text_features(
|
||||
text_inputs,
|
||||
)
|
||||
embeds.append(text_embeds)
|
||||
|
||||
inputs["prompt_embeds"], inputs["negative_prompt_embeds"] = embeds
|
||||
|
||||
# forward
|
||||
output = musicldm_pipe(**inputs)
|
||||
audio_2 = output.audios[0]
|
||||
|
||||
assert np.abs(audio_1 - audio_2).max() < 1e-2
|
||||
|
||||
def test_musicldm_negative_prompt(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
components["scheduler"] = PNDMScheduler(skip_prk_steps=True)
|
||||
musicldm_pipe = MusicLDMPipeline(**components)
|
||||
musicldm_pipe = musicldm_pipe.to(device)
|
||||
musicldm_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
negative_prompt = "egg cracking"
|
||||
output = musicldm_pipe(**inputs, negative_prompt=negative_prompt)
|
||||
audio = output.audios[0]
|
||||
|
||||
assert audio.ndim == 1
|
||||
assert len(audio) == 256
|
||||
|
||||
audio_slice = audio[:10]
|
||||
expected_slice = np.array(
|
||||
[-0.0027, -0.0036, -0.0037, -0.0019, -0.0035, -0.0018, -0.0037, -0.0021, -0.0038, -0.0018]
|
||||
)
|
||||
|
||||
assert np.abs(audio_slice - expected_slice).max() < 1e-4
|
||||
|
||||
def test_musicldm_num_waveforms_per_prompt(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
components["scheduler"] = PNDMScheduler(skip_prk_steps=True)
|
||||
musicldm_pipe = MusicLDMPipeline(**components)
|
||||
musicldm_pipe = musicldm_pipe.to(device)
|
||||
musicldm_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A hammer hitting a wooden surface"
|
||||
|
||||
# test num_waveforms_per_prompt=1 (default)
|
||||
audios = musicldm_pipe(prompt, num_inference_steps=2).audios
|
||||
|
||||
assert audios.shape == (1, 256)
|
||||
|
||||
# test num_waveforms_per_prompt=1 (default) for batch of prompts
|
||||
batch_size = 2
|
||||
audios = musicldm_pipe([prompt] * batch_size, num_inference_steps=2).audios
|
||||
|
||||
assert audios.shape == (batch_size, 256)
|
||||
|
||||
# test num_waveforms_per_prompt for single prompt
|
||||
num_waveforms_per_prompt = 2
|
||||
audios = musicldm_pipe(prompt, num_inference_steps=2, num_waveforms_per_prompt=num_waveforms_per_prompt).audios
|
||||
|
||||
assert audios.shape == (num_waveforms_per_prompt, 256)
|
||||
|
||||
# test num_waveforms_per_prompt for batch of prompts
|
||||
batch_size = 2
|
||||
audios = musicldm_pipe(
|
||||
[prompt] * batch_size, num_inference_steps=2, num_waveforms_per_prompt=num_waveforms_per_prompt
|
||||
).audios
|
||||
|
||||
assert audios.shape == (batch_size * num_waveforms_per_prompt, 256)
|
||||
|
||||
def test_musicldm_audio_length_in_s(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
musicldm_pipe = MusicLDMPipeline(**components)
|
||||
musicldm_pipe = musicldm_pipe.to(torch_device)
|
||||
musicldm_pipe.set_progress_bar_config(disable=None)
|
||||
vocoder_sampling_rate = musicldm_pipe.vocoder.config.sampling_rate
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
output = musicldm_pipe(audio_length_in_s=0.016, **inputs)
|
||||
audio = output.audios[0]
|
||||
|
||||
assert audio.ndim == 1
|
||||
assert len(audio) / vocoder_sampling_rate == 0.016
|
||||
|
||||
output = musicldm_pipe(audio_length_in_s=0.032, **inputs)
|
||||
audio = output.audios[0]
|
||||
|
||||
assert audio.ndim == 1
|
||||
assert len(audio) / vocoder_sampling_rate == 0.032
|
||||
|
||||
def test_musicldm_vocoder_model_in_dim(self):
|
||||
components = self.get_dummy_components()
|
||||
musicldm_pipe = MusicLDMPipeline(**components)
|
||||
musicldm_pipe = musicldm_pipe.to(torch_device)
|
||||
musicldm_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = ["hey"]
|
||||
|
||||
output = musicldm_pipe(prompt, num_inference_steps=1)
|
||||
audio_shape = output.audios.shape
|
||||
assert audio_shape == (1, 256)
|
||||
|
||||
config = musicldm_pipe.vocoder.config
|
||||
config.model_in_dim *= 2
|
||||
musicldm_pipe.vocoder = SpeechT5HifiGan(config).to(torch_device)
|
||||
output = musicldm_pipe(prompt, num_inference_steps=1)
|
||||
audio_shape = output.audios.shape
|
||||
# waveform shape is unchanged, we just have 2x the number of mel channels in the spectrogram
|
||||
assert audio_shape == (1, 256)
|
||||
|
||||
def test_attention_slicing_forward_pass(self):
|
||||
self._test_attention_slicing_forward_pass(test_mean_pixel_difference=False)
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(test_mean_pixel_difference=False)
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_xformers_attention_forwardGenerator_pass(self):
|
||||
self._test_xformers_attention_forwardGenerator_pass(test_mean_pixel_difference=False)
|
||||
|
||||
def test_to_dtype(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# The method component.dtype returns the dtype of the first parameter registered in the model, not the
|
||||
# dtype of the entire model. In the case of CLAP, the first parameter is a float64 constant (logit scale)
|
||||
model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
|
||||
|
||||
# Without the logit scale parameters, everything is float32
|
||||
model_dtypes.pop("text_encoder")
|
||||
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes.values()))
|
||||
|
||||
# the CLAP sub-models are float32
|
||||
model_dtypes["clap_text_branch"] = components["text_encoder"].text_model.dtype
|
||||
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes.values()))
|
||||
|
||||
# Once we send to fp16, all params are in half-precision, including the logit scale
|
||||
pipe.to(torch_dtype=torch.float16)
|
||||
model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
|
||||
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values()))
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_gpu
|
||||
class MusicLDMPipelineNightlyTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
|
||||
generator = torch.Generator(device=generator_device).manual_seed(seed)
|
||||
latents = np.random.RandomState(seed).standard_normal((1, 8, 128, 16))
|
||||
latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
|
||||
inputs = {
|
||||
"prompt": "A hammer hitting a wooden surface",
|
||||
"latents": latents,
|
||||
"generator": generator,
|
||||
"num_inference_steps": 3,
|
||||
"guidance_scale": 2.5,
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_musicldm(self):
|
||||
musicldm_pipe = MusicLDMPipeline.from_pretrained("cvssp/musicldm")
|
||||
musicldm_pipe = musicldm_pipe.to(torch_device)
|
||||
musicldm_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
inputs["num_inference_steps"] = 25
|
||||
audio = musicldm_pipe(**inputs).audios[0]
|
||||
|
||||
assert audio.ndim == 1
|
||||
assert len(audio) == 81952
|
||||
|
||||
# check the portion of the generated audio with the largest dynamic range (reduces flakiness)
|
||||
audio_slice = audio[8680:8690]
|
||||
expected_slice = np.array(
|
||||
[-0.1042, -0.1068, -0.1235, -0.1387, -0.1428, -0.136, -0.1213, -0.1097, -0.0967, -0.0945]
|
||||
)
|
||||
max_diff = np.abs(expected_slice - audio_slice).max()
|
||||
assert max_diff < 1e-3
|
||||
|
||||
def test_musicldm_lms(self):
|
||||
musicldm_pipe = MusicLDMPipeline.from_pretrained("cvssp/musicldm")
|
||||
musicldm_pipe.scheduler = LMSDiscreteScheduler.from_config(musicldm_pipe.scheduler.config)
|
||||
musicldm_pipe = musicldm_pipe.to(torch_device)
|
||||
musicldm_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
audio = musicldm_pipe(**inputs).audios[0]
|
||||
|
||||
assert audio.ndim == 1
|
||||
assert len(audio) == 81952
|
||||
|
||||
# check the portion of the generated audio with the largest dynamic range (reduces flakiness)
|
||||
audio_slice = audio[58020:58030]
|
||||
expected_slice = np.array([0.3592, 0.3477, 0.4084, 0.4665, 0.5048, 0.5891, 0.6461, 0.5579, 0.4595, 0.4403])
|
||||
max_diff = np.abs(expected_slice - audio_slice).max()
|
||||
assert max_diff < 1e-3
|
||||
Reference in New Issue
Block a user