Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
Add AudioLDM 2 (#4549)
* from audioldm
* unet down + mid
* vae, clap, flan-t5
* start sequence audio mae
* iterate on audioldm encoder
* finish encoder
* finish weight conversion
* text pre-processing
* gpt2 pre-processing
* fix projection model
* working
* unet equivalence
* finish in base
* add unet cond
* finish unet
* finish custom unet
* start clean-up
* revert base unet changes
* refactor pre-processing
* tests: from audioldm
* fix some tests
* more fixes
* iterate on tests
* make fix copies
* harden fast tests
* slow integration tests
* finish tests
* update checkpoint
* update copyright
* docs
* remove outdated method
* add docstring
* make style
* remove decode latents
* enable cpu offload
* (text_encoder_1, tokenizer_1) -> (text_encoder, tokenizer)
* more clean up
* more refactor
* build pr docs
* Update docs/source/en/api/pipelines/audioldm2.md
  Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* small clean
* tidy conversion
* update for large checkpoint
* generate -> generate_language_model
* full clap model
* shrink clap-audio in tests
* fix large integration test
* fix fast tests
* use generation config
* make style
* update docs
* finish docs
* finish doc
* update tests
* fix last test
* syntax
* finalise tests
* refactor projection model in prep for TTS
* fix fast tests
* style

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
@@ -190,6 +190,8 @@
      title: Audio Diffusion
  - local: api/pipelines/audioldm
    title: AudioLDM
  - local: api/pipelines/audioldm2
    title: AudioLDM 2
  - local: api/pipelines/auto_pipeline
    title: AutoPipeline
  - local: api/pipelines/consistency_models
docs/source/en/api/pipelines/audioldm2.md (new file, 116 lines)
@@ -0,0 +1,116 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# AudioLDM 2

AudioLDM 2 was proposed in [AudioLDM 2: Learning Holistic Audio Generation with Self-supervised Pretraining](https://arxiv.org/abs/2308.05734)
by Haohe Liu et al. AudioLDM 2 takes a text prompt as input and predicts the corresponding audio. It can generate
text-conditional sound effects, human speech and music.

Inspired by [Stable Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview), AudioLDM 2
is a text-to-audio _latent diffusion model (LDM)_ that learns continuous audio representations from text embeddings. Two
text encoder models are used to compute the text embeddings from a prompt input: the text-branch of [CLAP](https://huggingface.co/docs/transformers/main/en/model_doc/clap)
and the encoder of [Flan-T5](https://huggingface.co/docs/transformers/main/en/model_doc/flan-t5). These text embeddings
are then projected to a shared embedding space by an [AudioLDM2ProjectionModel](https://huggingface.co/docs/diffusers/api/pipelines/audioldm2/AudioLDM2ProjectionModel).
A [GPT2](https://huggingface.co/docs/transformers/main/en/model_doc/gpt2) _language model (LM)_ is used to auto-regressively
predict eight new embedding vectors, conditional on the projected CLAP and Flan-T5 embeddings. The generated embedding
vectors and Flan-T5 text embeddings are used as cross-attention conditioning in the LDM. The [UNet](https://huggingface.co/docs/diffusers/api/pipelines/audioldm2/AudioLDM2UNet2DConditionModel)
of AudioLDM 2 is unique in the sense that it takes **two** cross-attention embeddings, as opposed to one cross-attention
conditioning, as in most other LDMs.
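
The pipeline's components mirror this description one-to-one. The sketch below is one way to see them on a loaded pipeline; it assumes the official `cvssp/audioldm2` checkpoint and simply prints the class of each registered module:

```python
from diffusers import AudioLDM2Pipeline

pipe = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2")

print(pipe.text_encoder.__class__.__name__)      # ClapModel (CLAP text branch)
print(pipe.text_encoder_2.__class__.__name__)    # T5EncoderModel (Flan-T5)
print(pipe.projection_model.__class__.__name__)  # AudioLDM2ProjectionModel
print(pipe.language_model.__class__.__name__)    # GPT2Model
print(pipe.unet.__class__.__name__)              # AudioLDM2UNet2DConditionModel
```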

The abstract of the paper is the following:

*Although audio generation shares commonalities across different types of audio, such as speech, music, and sound effects, designing models for each type requires careful consideration of specific objectives and biases that can significantly differ from those of other types. To bring us closer to a unified perspective of audio generation, this paper proposes a framework that utilizes the same learning method for speech, music, and sound effect generation. Our framework introduces a general representation of audio, called language of audio (LOA). Any audio can be translated into LOA based on AudioMAE, a self-supervised pre-trained representation learning model. In the generation process, we translate any modalities into LOA by using a GPT-2 model, and we perform self-supervised audio generation learning with a latent diffusion model conditioned on LOA. The proposed framework naturally brings advantages such as in-context learning abilities and reusable self-supervised pretrained AudioMAE and latent diffusion models. Experiments on the major benchmarks of text-to-audio, text-to-music, and text-to-speech demonstrate new state-of-the-art or competitive performance to previous approaches.*

This pipeline was contributed by [sanchit-gandhi](https://huggingface.co/sanchit-gandhi). The original codebase can be
found at [haoheliu/audioldm2](https://github.com/haoheliu/audioldm2).

## Tips

### Choosing a checkpoint

AudioLDM2 comes in three variants. Two of these checkpoints are applicable to the general task of text-to-audio generation. The third checkpoint is trained exclusively on text-to-music generation. See the table below for details on the three official checkpoints:

| Checkpoint                                                      | Task          | Model size | Training data (hours) |
|-----------------------------------------------------------------|---------------|------------|-----------------------|
| [audioldm2](https://huggingface.co/cvssp/audioldm2)             | Text-to-audio | 1.1B       | 1150k                 |
| [audioldm2-music](https://huggingface.co/cvssp/audioldm2-music) | Text-to-music | 1.1B       | 665k                  |
| [audioldm2-large](https://huggingface.co/cvssp/audioldm2-large) | Text-to-audio | 1.5B       | 1150k                 |
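
To target one of these tasks, swap the repository id when loading the pipeline. A minimal sketch (half precision and CUDA are optional):

```python
import torch
from diffusers import AudioLDM2Pipeline

# pick the checkpoint that matches your task from the table above
repo_id = "cvssp/audioldm2-large"  # or "cvssp/audioldm2" / "cvssp/audioldm2-music"

pipe = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
pipe = pipe.to("cuda")
```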

### Constructing a prompt

* Descriptive prompt inputs work best: use adjectives to describe the sound (e.g. "high quality" or "clear") and make the prompt context specific (e.g. "water stream in a forest" instead of "stream"), as in the short example below.
* It's best to use general terms like "cat" or "dog" instead of specific names or abstract objects the model may not be familiar with.
* Using a **negative prompt** can significantly improve the quality of the generated waveform, by guiding the generation away from terms that correspond to poor quality audio. Try using a negative prompt of "Low quality."
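
As a rough illustration of these guidelines (the exact wording is only an example, not a prescription from the authors):

```python
# descriptive and context specific, rather than just "stream"
prompt = "High quality recording of a water stream flowing through a forest, clear"
# steer the generation away from low quality audio
negative_prompt = "Low quality."
```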

### Controlling inference

* The _quality_ of the predicted audio sample can be controlled by the `num_inference_steps` argument; higher steps give higher quality audio at the expense of slower inference.
* The _length_ of the predicted audio sample can be controlled by varying the `audio_length_in_s` argument.

### Evaluating generated waveforms

* The quality of the generated waveforms can vary significantly based on the seed. Try generating with different seeds until you find a satisfactory generation (see the seed sweep sketch below).
* Multiple waveforms can be generated in one go: set `num_waveforms_per_prompt` to a value greater than 1. Automatic scoring will be performed between the generated waveforms and prompt text, and the audios ranked from best to worst accordingly.
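
A minimal seed sweep looks like the sketch below; it assumes `pipe` is an already loaded `AudioLDM2Pipeline` on a CUDA device, and the prompt is only an example:

```python
import torch

prompt = "Techno music with a strong, upbeat tempo and high melodic riffs"

# re-run the same prompt with different seeds and keep the waveform you like best
for seed in [0, 1, 2, 3]:
    generator = torch.Generator("cuda").manual_seed(seed)
    audio = pipe(prompt, num_inference_steps=200, audio_length_in_s=10.0, generator=generator).audios[0]
    # listen to / save each candidate here
```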

The following example demonstrates how to construct a prompt for good music generation using the aforementioned tips:

```python
import scipy
import torch
from diffusers import AudioLDM2Pipeline

# load the best weights for music generation
repo_id = "cvssp/audioldm2-music"
pipe = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# define the prompts
prompt = "Techno music with a strong, upbeat tempo and high melodic riffs"
negative_prompt = "Low quality."

# set the seed
generator = torch.Generator("cuda").manual_seed(0)

# run the generation
audio = pipe(
    prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=200,
    audio_length_in_s=10.0,
    num_waveforms_per_prompt=3,
    generator=generator,
).audios

# save the best audio sample (index 0) as a .wav file
scipy.io.wavfile.write("techno.wav", rate=16000, data=audio[0])
```
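
If the pipeline does not fit into GPU memory, the offloading helpers defined on the pipeline can be enabled before running it (a sketch; the exact savings depend on your hardware, and model offload requires `accelerate`):

```python
# move whole sub-models to the GPU only while they are needed
pipe.enable_model_cpu_offload()
# optionally decode the VAE in slices to reduce peak memory
pipe.enable_vae_slicing()
```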

<Tip>

Make sure to check out the Schedulers [guide](/using-diffusers/schedulers) to learn how to explore the tradeoff between
scheduler speed and quality, and see the [reuse components across pipelines](/using-diffusers/loading#reuse-components-across-pipelines)
section to learn how to efficiently load the same components into multiple pipelines.

</Tip>
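
For instance, a different scheduler can be swapped in through its `from_config` method. This is a sketch only: `DPMSolverMultistepScheduler` is one possible choice and may let you reduce `num_inference_steps`, but it is not the scheduler the AudioLDM 2 authors validated:

```python
from diffusers import DPMSolverMultistepScheduler

pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
```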

## AudioLDM2Pipeline
[[autodoc]] AudioLDM2Pipeline
  - all
  - __call__

## AudioLDM2ProjectionModel
[[autodoc]] AudioLDM2ProjectionModel
  - forward

## AudioLDM2UNet2DConditionModel
[[autodoc]] AudioLDM2UNet2DConditionModel
  - forward
scripts/convert_original_audioldm2_to_diffusers.py (new file, 1140 lines)
File diff suppressed because it is too large.
@@ -133,6 +133,9 @@ else:
    from .pipelines import (
        AltDiffusionImg2ImgPipeline,
        AltDiffusionPipeline,
        AudioLDM2Pipeline,
        AudioLDM2ProjectionModel,
        AudioLDM2UNet2DConditionModel,
        AudioLDMPipeline,
        CycleDiffusionPipeline,
        IFImg2ImgPipeline,
@@ -88,6 +88,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
        num_embeds_ada_norm: Optional[int] = None,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_type: str = "layer_norm",
        norm_elementwise_affine: bool = True,
@@ -181,6 +182,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
            num_embeds_ada_norm=num_embeds_ada_norm,
            attention_bias=attention_bias,
            only_cross_attention=only_cross_attention,
            double_self_attention=double_self_attention,
            upcast_attention=upcast_attention,
            norm_type=norm_type,
            norm_elementwise_affine=norm_elementwise_affine,
@@ -46,6 +46,7 @@ except OptionalDependencyNotAvailable:
else:
    from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
    from .audioldm import AudioLDMPipeline
    from .audioldm2 import AudioLDM2Pipeline, AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
    from .controlnet import (
        StableDiffusionControlNetImg2ImgPipeline,
        StableDiffusionControlNetInpaintPipeline,
src/diffusers/pipelines/audioldm2/__init__.py (new file, 19 lines)
@@ -0,0 +1,19 @@
from ...utils import (
    OptionalDependencyNotAvailable,
    is_torch_available,
    is_transformers_available,
    is_transformers_version,
)


try:
    if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils.dummy_torch_and_transformers_objects import (
        AudioLDM2ProjectionModel,
        AudioLDMPipeline,
    )
else:
    from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
    from .pipeline_audioldm2 import AudioLDM2Pipeline
src/diffusers/pipelines/audioldm2/modeling_audioldm2.py (new file, 1496 lines)
File diff suppressed because it is too large.
src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py (new file, 955 lines)
@@ -0,0 +1,955 @@
# Copyright 2023 CVSSP, ByteDance and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import torch
from transformers import (
    ClapFeatureExtractor,
    ClapModel,
    GPT2Model,
    RobertaTokenizer,
    RobertaTokenizerFast,
    SpeechT5HifiGan,
    T5EncoderModel,
    T5Tokenizer,
    T5TokenizerFast,
)

from ...models import AutoencoderKL
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
    is_accelerate_available,
    is_accelerate_version,
    is_librosa_available,
    logging,
    randn_tensor,
    replace_example_docstring,
)
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel


if is_librosa_available():
    import librosa

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> from diffusers import AudioLDM2Pipeline
        >>> import torch
        >>> import scipy

        >>> repo_id = "cvssp/audioldm2"
        >>> pipe = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
        >>> pipe = pipe.to("cuda")

        >>> prompt = "Techno music with a strong, upbeat tempo and high melodic riffs"
        >>> audio = pipe(prompt, num_inference_steps=200, audio_length_in_s=10.0).audios[0]

        >>> # save the audio sample as a .wav file
        >>> scipy.io.wavfile.write("techno.wav", rate=16000, data=audio)
        ```
"""


def prepare_inputs_for_generation(
    inputs_embeds,
    attention_mask=None,
    past_key_values=None,
    **kwargs,
):
    if past_key_values is not None:
        # only last token for inputs_embeds if past is defined in kwargs
        inputs_embeds = inputs_embeds[:, -1:]

    return {
        "inputs_embeds": inputs_embeds,
        "attention_mask": attention_mask,
        "past_key_values": past_key_values,
        "use_cache": kwargs.get("use_cache"),
    }


class AudioLDM2Pipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-audio generation using AudioLDM2.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`~transformers.ClapModel`]):
            First frozen text-encoder. AudioLDM2 uses the joint audio-text embedding model
            [CLAP](https://huggingface.co/docs/transformers/model_doc/clap#transformers.CLAPTextModelWithProjection),
            specifically the [laion/clap-htsat-unfused](https://huggingface.co/laion/clap-htsat-unfused) variant. The
            text branch is used to encode the text prompt to a prompt embedding. The full audio-text model is used to
            rank generated waveforms against the text prompt by computing similarity scores.
        text_encoder_2 ([`~transformers.T5EncoderModel`]):
            Second frozen text-encoder. AudioLDM2 uses the encoder of
            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
            [google/flan-t5-large](https://huggingface.co/google/flan-t5-large) variant.
        projection_model ([`AudioLDM2ProjectionModel`]):
            A trained model used to linearly project the hidden-states from the first and second text encoder models
            and insert learned SOS and EOS token embeddings. The projected hidden-states from the two text encoders are
            concatenated to give the input to the language model.
        language_model ([`~transformers.GPT2Model`]):
            An auto-regressive language model used to generate a sequence of hidden-states conditioned on the projected
            outputs from the two text encoders.
        tokenizer ([`~transformers.RobertaTokenizer`]):
            Tokenizer to tokenize text for the first frozen text-encoder.
        tokenizer_2 ([`~transformers.T5Tokenizer`]):
            Tokenizer to tokenize text for the second frozen text-encoder.
        feature_extractor ([`~transformers.ClapFeatureExtractor`]):
            Feature extractor to pre-process generated audio waveforms to log-mel spectrograms for automatic scoring.
        unet ([`UNet2DConditionModel`]):
            A `UNet2DConditionModel` to denoise the encoded audio latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded audio latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        vocoder ([`~transformers.SpeechT5HifiGan`]):
            Vocoder of class `SpeechT5HifiGan` to convert the mel-spectrogram latents to the final audio waveform.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: ClapModel,
        text_encoder_2: T5EncoderModel,
        projection_model: AudioLDM2ProjectionModel,
        language_model: GPT2Model,
        tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast],
        tokenizer_2: Union[T5Tokenizer, T5TokenizerFast],
        feature_extractor: ClapFeatureExtractor,
        unet: AudioLDM2UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        vocoder: SpeechT5HifiGan,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            projection_model=projection_model,
            language_model=language_model,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            feature_extractor=feature_extractor,
            unet=unet,
            scheduler=scheduler,
            vocoder=vocoder,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    def enable_model_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")

        device = torch.device(f"cuda:{gpu_id}")

        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        model_sequence = [
            self.text_encoder,
            self.text_encoder_2,
            self.projection_model,
            self.language_model,
            self.unet,
            self.vae,
        ]

        hook = None
        for cpu_offloaded_model in model_sequence:
            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

        # We'll offload the last model manually.
        self.final_offload_hook = hook

    def generate_language_model(
        self,
        inputs_embeds: torch.Tensor = None,
        max_new_tokens: int = 8,
        **model_kwargs,
    ):
        """
        Generates a sequence of hidden-states from the language model, conditioned on the embedding inputs.

        Parameters:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                The sequence used as a prompt for the generation.
            max_new_tokens (`int`):
                Number of new tokens to generate.
            model_kwargs (`Dict[str, Any]`, *optional*):
                Ad hoc parametrization of additional model-specific kwargs that will be forwarded to the `forward`
                function of the model.

        Return:
            `inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                The sequence of generated hidden-states.
        """
        max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens
        for _ in range(max_new_tokens):
            # prepare model inputs
            model_inputs = prepare_inputs_for_generation(inputs_embeds, **model_kwargs)

            # forward pass to get next hidden states
            output = self.language_model(**model_inputs, return_dict=True)

            next_hidden_states = output.last_hidden_state

            # Update the model input
            inputs_embeds = torch.cat([inputs_embeds, next_hidden_states[:, -1:, :]], dim=1)

            # Update generated hidden states, model inputs, and length for next step
            model_kwargs = self.language_model._update_model_kwargs_for_generation(output, model_kwargs)

        return inputs_embeds[:, -max_new_tokens:, :]
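
    # Shape note (illustration only, not executed): for projected prompt embeddings of shape
    # (batch_size, seq_len, hidden_size), each loop iteration above appends the language model's last
    # hidden-state, so the returned tensor has shape (batch_size, max_new_tokens, hidden_size). With the
    # default max_new_tokens=8 these are the eight new embedding vectors used as the primary
    # cross-attention conditioning of the UNet.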

    def encode_prompt(
        self,
        prompt,
        device,
        num_waveforms_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        generated_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_generated_prompt_embeds: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        negative_attention_mask: Optional[torch.LongTensor] = None,
        max_new_tokens: Optional[int] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device (`torch.device`):
                torch device
            num_waveforms_per_prompt (`int`):
                number of waveforms that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the audio generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-computed text embeddings from the Flan T5 model. Can be used to easily tweak text inputs, *e.g.*
                prompt weighting. If not provided, text embeddings will be computed from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-computed negative text embeddings from the Flan T5 model. Can be used to easily tweak text inputs,
                *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from
                `negative_prompt` input argument.
            generated_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings from the GPT2 language model. Can be used to easily tweak text inputs,
                *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input
                argument.
            negative_generated_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings from the GPT2 language model. Can be used to easily tweak text
                inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from
                `negative_prompt` input argument.
            attention_mask (`torch.LongTensor`, *optional*):
                Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, attention mask will
                be computed from `prompt` input argument.
            negative_attention_mask (`torch.LongTensor`, *optional*):
                Pre-computed attention mask to be applied to the `negative_prompt_embeds`. If not provided, attention
                mask will be computed from `negative_prompt` input argument.
            max_new_tokens (`int`, *optional*, defaults to None):
                The number of new tokens to generate with the GPT2 language model.
        Returns:
            prompt_embeds (`torch.FloatTensor`):
                Text embeddings from the Flan T5 model.
            attention_mask (`torch.LongTensor`):
                Attention mask to be applied to the `prompt_embeds`.
            generated_prompt_embeds (`torch.FloatTensor`):
                Text embeddings generated from the GPT2 language model.

        Example:

        ```python
        >>> import torch
        >>> from diffusers import AudioLDM2Pipeline

        >>> repo_id = "cvssp/audioldm2"
        >>> pipe = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
        >>> pipe = pipe.to("cuda")

        >>> # Get text embedding vectors
        >>> prompt_embeds, attention_mask, generated_prompt_embeds = pipe.encode_prompt(
        ...     prompt="Techno music with a strong, upbeat tempo and high melodic riffs",
        ...     device="cuda",
        ...     num_waveforms_per_prompt=1,
        ...     do_classifier_free_guidance=True,
        ... )

        >>> # Pass text embeddings to pipeline for text-conditional audio generation
        >>> audio = pipe(
        ...     prompt_embeds=prompt_embeds,
        ...     attention_mask=attention_mask,
        ...     generated_prompt_embeds=generated_prompt_embeds,
        ...     num_inference_steps=200,
        ...     audio_length_in_s=10.0,
        ... ).audios[0]
        ```"""
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # Define tokenizers and text encoders
        tokenizers = [self.tokenizer, self.tokenizer_2]
        text_encoders = [self.text_encoder, self.text_encoder_2]

        if prompt_embeds is None:
            prompt_embeds_list = []
            attention_mask_list = []

            for tokenizer, text_encoder in zip(tokenizers, text_encoders):
                text_inputs = tokenizer(
                    prompt,
                    padding="max_length" if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast)) else True,
                    max_length=tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                )
                text_input_ids = text_inputs.input_ids
                attention_mask = text_inputs.attention_mask
                untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

                if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                    text_input_ids, untruncated_ids
                ):
                    removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
                    logger.warning(
                        f"The following part of your input was truncated because {text_encoder.config.model_type} can "
                        f"only handle sequences up to {tokenizer.model_max_length} tokens: {removed_text}"
                    )

                text_input_ids = text_input_ids.to(device)
                attention_mask = attention_mask.to(device)

                if text_encoder.config.model_type == "clap":
                    prompt_embeds = text_encoder.get_text_features(
                        text_input_ids,
                        attention_mask=attention_mask,
                    )
                    # append the seq-len dim: (bs, hidden_size) -> (bs, seq_len, hidden_size)
                    prompt_embeds = prompt_embeds[:, None, :]
                    # make sure that we attend to this single hidden-state
                    attention_mask = attention_mask.new_ones((batch_size, 1))
                else:
                    prompt_embeds = text_encoder(
                        text_input_ids,
                        attention_mask=attention_mask,
                    )
                    prompt_embeds = prompt_embeds[0]

                prompt_embeds_list.append(prompt_embeds)
                attention_mask_list.append(attention_mask)

            projection_output = self.projection_model(
                hidden_states=prompt_embeds_list[0],
                hidden_states_1=prompt_embeds_list[1],
                attention_mask=attention_mask_list[0],
                attention_mask_1=attention_mask_list[1],
            )
            projected_prompt_embeds = projection_output.hidden_states
            projected_attention_mask = projection_output.attention_mask

            generated_prompt_embeds = self.generate_language_model(
                projected_prompt_embeds,
                attention_mask=projected_attention_mask,
                max_new_tokens=max_new_tokens,
            )

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
        attention_mask = (
            attention_mask.to(device=device)
            if attention_mask is not None
            else torch.ones(prompt_embeds.shape[:2], dtype=torch.long, device=device)
        )
        generated_prompt_embeds = generated_prompt_embeds.to(dtype=self.language_model.dtype, device=device)

        bs_embed, seq_len, hidden_size = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_waveforms_per_prompt, seq_len, hidden_size)

        # duplicate attention mask for each generation per prompt
        attention_mask = attention_mask.repeat(1, num_waveforms_per_prompt)
        attention_mask = attention_mask.view(bs_embed * num_waveforms_per_prompt, seq_len)

        bs_embed, seq_len, hidden_size = generated_prompt_embeds.shape
        # duplicate generated embeddings for each generation per prompt, using mps friendly method
        generated_prompt_embeds = generated_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
        generated_prompt_embeds = generated_prompt_embeds.view(
            bs_embed * num_waveforms_per_prompt, seq_len, hidden_size
        )

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            negative_prompt_embeds_list = []
            negative_attention_mask_list = []
            max_length = prompt_embeds.shape[1]
            for tokenizer, text_encoder in zip(tokenizers, text_encoders):
                uncond_input = tokenizer(
                    uncond_tokens,
                    padding="max_length",
                    max_length=tokenizer.model_max_length
                    if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast))
                    else max_length,
                    truncation=True,
                    return_tensors="pt",
                )

                uncond_input_ids = uncond_input.input_ids.to(device)
                negative_attention_mask = uncond_input.attention_mask.to(device)

                if text_encoder.config.model_type == "clap":
                    negative_prompt_embeds = text_encoder.get_text_features(
                        uncond_input_ids,
                        attention_mask=negative_attention_mask,
                    )
                    # append the seq-len dim: (bs, hidden_size) -> (bs, seq_len, hidden_size)
                    negative_prompt_embeds = negative_prompt_embeds[:, None, :]
                    # make sure that we attend to this single hidden-state
                    negative_attention_mask = negative_attention_mask.new_ones((batch_size, 1))
                else:
                    negative_prompt_embeds = text_encoder(
                        uncond_input_ids,
                        attention_mask=negative_attention_mask,
                    )
                    negative_prompt_embeds = negative_prompt_embeds[0]

                negative_prompt_embeds_list.append(negative_prompt_embeds)
                negative_attention_mask_list.append(negative_attention_mask)

            projection_output = self.projection_model(
                hidden_states=negative_prompt_embeds_list[0],
                hidden_states_1=negative_prompt_embeds_list[1],
                attention_mask=negative_attention_mask_list[0],
                attention_mask_1=negative_attention_mask_list[1],
            )
            negative_projected_prompt_embeds = projection_output.hidden_states
            negative_projected_attention_mask = projection_output.attention_mask

            negative_generated_prompt_embeds = self.generate_language_model(
                negative_projected_prompt_embeds,
                attention_mask=negative_projected_attention_mask,
                max_new_tokens=max_new_tokens,
            )

        if do_classifier_free_guidance:
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
            negative_attention_mask = (
                negative_attention_mask.to(device=device)
                if negative_attention_mask is not None
                else torch.ones(negative_prompt_embeds.shape[:2], dtype=torch.long, device=device)
            )
            negative_generated_prompt_embeds = negative_generated_prompt_embeds.to(
                dtype=self.language_model.dtype, device=device
            )

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_waveforms_per_prompt, seq_len, -1)

            # duplicate unconditional attention mask for each generation per prompt
            negative_attention_mask = negative_attention_mask.repeat(1, num_waveforms_per_prompt)
            negative_attention_mask = negative_attention_mask.view(batch_size * num_waveforms_per_prompt, seq_len)

            # duplicate unconditional generated embeddings for each generation per prompt
            seq_len = negative_generated_prompt_embeds.shape[1]
            negative_generated_prompt_embeds = negative_generated_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
            negative_generated_prompt_embeds = negative_generated_prompt_embeds.view(
                batch_size * num_waveforms_per_prompt, seq_len, -1
            )

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
            attention_mask = torch.cat([negative_attention_mask, attention_mask])
            generated_prompt_embeds = torch.cat([negative_generated_prompt_embeds, generated_prompt_embeds])

        return prompt_embeds, attention_mask, generated_prompt_embeds
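
    # Note (illustration only): with classifier-free guidance enabled, the three tensors returned above
    # are the unconditional and conditional inputs concatenated along the batch dimension, negative
    # first. __call__ relies on this ordering when it splits the UNet output with noise_pred.chunk(2)
    # into (noise_pred_uncond, noise_pred_text).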

    # Copied from diffusers.pipelines.audioldm.pipeline_audioldm.AudioLDMPipeline.mel_spectrogram_to_waveform
    def mel_spectrogram_to_waveform(self, mel_spectrogram):
        if mel_spectrogram.dim() == 4:
            mel_spectrogram = mel_spectrogram.squeeze(1)

        waveform = self.vocoder(mel_spectrogram)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        waveform = waveform.cpu().float()
        return waveform

    def score_waveforms(self, text, audio, num_waveforms_per_prompt, device, dtype):
        if not is_librosa_available():
            logger.info(
                "Automatic scoring of the generated audio waveforms against the input prompt text requires the "
                "`librosa` package to resample the generated waveforms. Returning the audios in the order they were "
                "generated. To enable automatic scoring, install `librosa` with: `pip install librosa`."
            )
            return audio
        inputs = self.tokenizer(text, return_tensors="pt", padding=True)
        resampled_audio = librosa.resample(
            audio.numpy(), orig_sr=self.vocoder.config.sampling_rate, target_sr=self.feature_extractor.sampling_rate
        )
        inputs["input_features"] = self.feature_extractor(
            list(resampled_audio), return_tensors="pt", sampling_rate=self.feature_extractor.sampling_rate
        ).input_features.type(dtype)
        inputs = inputs.to(device)

        # compute the audio-text similarity score using the CLAP model
        logits_per_text = self.text_encoder(**inputs).logits_per_text
        # sort by the highest matching generations per prompt
        indices = torch.argsort(logits_per_text, dim=1, descending=True)[:, :num_waveforms_per_prompt]
        audio = torch.index_select(audio, 0, indices.reshape(-1).cpu())
        return audio
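
    # Note (illustration only): for a single prompt with num_waveforms_per_prompt=3, logits_per_text has
    # shape (1, 3); sorting it in descending order re-orders the three generated waveforms from best to
    # worst CLAP text-audio match, which is why the docs example simply saves audio[0] as the best candidate.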

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def check_inputs(
        self,
        prompt,
        audio_length_in_s,
        vocoder_upsample_factor,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        generated_prompt_embeds=None,
        negative_generated_prompt_embeds=None,
        attention_mask=None,
        negative_attention_mask=None,
    ):
        min_audio_length_in_s = vocoder_upsample_factor * self.vae_scale_factor
        if audio_length_in_s < min_audio_length_in_s:
            raise ValueError(
                f"`audio_length_in_s` has to be a positive value greater than or equal to {min_audio_length_in_s}, but "
                f"is {audio_length_in_s}."
            )

        if self.vocoder.config.model_in_dim % self.vae_scale_factor != 0:
            raise ValueError(
                f"The number of frequency bins in the vocoder's log-mel spectrogram has to be divisible by the "
                f"VAE scale factor, but got {self.vocoder.config.model_in_dim} bins and a scale factor of "
                f"{self.vae_scale_factor}."
            )

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and (prompt_embeds is None or generated_prompt_embeds is None):
            raise ValueError(
                "Provide either `prompt`, or `prompt_embeds` and `generated_prompt_embeds`. Cannot leave "
                "`prompt` undefined without specifying both `prompt_embeds` and `generated_prompt_embeds`."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )
        elif negative_prompt_embeds is not None and negative_generated_prompt_embeds is None:
            raise ValueError(
                "Cannot forward `negative_prompt_embeds` without `negative_generated_prompt_embeds`. Ensure that "
                "both arguments are specified"
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )
            if attention_mask is not None and attention_mask.shape != prompt_embeds.shape[:2]:
                raise ValueError(
                    "`attention_mask should have the same batch size and sequence length as `prompt_embeds`, but got:"
                    f"`attention_mask: {attention_mask.shape} != `prompt_embeds` {prompt_embeds.shape}"
                )

        if generated_prompt_embeds is not None and negative_generated_prompt_embeds is not None:
            if generated_prompt_embeds.shape != negative_generated_prompt_embeds.shape:
                raise ValueError(
                    "`generated_prompt_embeds` and `negative_generated_prompt_embeds` must have the same shape when "
                    f"passed directly, but got: `generated_prompt_embeds` {generated_prompt_embeds.shape} != "
                    f"`negative_generated_prompt_embeds` {negative_generated_prompt_embeds.shape}."
                )
            if (
                negative_attention_mask is not None
                and negative_attention_mask.shape != negative_prompt_embeds.shape[:2]
            ):
                raise ValueError(
                    "`attention_mask should have the same batch size and sequence length as `prompt_embeds`, but got:"
                    f"`attention_mask: {negative_attention_mask.shape} != `prompt_embeds` {negative_prompt_embeds.shape}"
                )

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with width->self.vocoder.config.model_in_dim
    def prepare_latents(self, batch_size, num_channels_latents, height, dtype, device, generator, latents=None):
        shape = (
            batch_size,
            num_channels_latents,
            height // self.vae_scale_factor,
            self.vocoder.config.model_in_dim // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        audio_length_in_s: Optional[float] = None,
        num_inference_steps: int = 200,
        guidance_scale: float = 3.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_waveforms_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        generated_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_generated_prompt_embeds: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        negative_attention_mask: Optional[torch.LongTensor] = None,
        max_new_tokens: Optional[int] = None,
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        output_type: Optional[str] = "np",
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`.
            audio_length_in_s (`int`, *optional*, defaults to 10.24):
                The length of the generated audio sample in seconds.
            num_inference_steps (`int`, *optional*, defaults to 200):
                The number of denoising steps. More denoising steps usually lead to a higher quality audio at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 3.5):
                A higher guidance scale value encourages the model to generate audio that is closely linked to the text
                `prompt` at the expense of lower sound quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in audio generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
            num_waveforms_per_prompt (`int`, *optional*, defaults to 1):
                The number of waveforms to generate per prompt. If `num_waveforms_per_prompt > 1`, then automatic
                scoring is performed between the generated outputs and the text prompt. This scoring ranks the
                generated waveforms based on their cosine similarity with the text input in the joint text-audio
                embedding space.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for spectrogram
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            generated_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings from the GPT2 language model. Can be used to easily tweak text inputs,
                *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input
                argument.
            negative_generated_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings from the GPT2 language model. Can be used to easily tweak text
                inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from
                `negative_prompt` input argument.
            attention_mask (`torch.LongTensor`, *optional*):
                Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, attention mask will
                be computed from `prompt` input argument.
            negative_attention_mask (`torch.LongTensor`, *optional*):
                Pre-computed attention mask to be applied to the `negative_prompt_embeds`. If not provided, attention
                mask will be computed from `negative_prompt` input argument.
            max_new_tokens (`int`, *optional*, defaults to None):
                Number of new tokens to generate with the GPT2 language model. If not provided, number of tokens will
                be taken from the config of the model.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that calls every `callback_steps` steps during inference. The function is called with the
                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            output_type (`str`, *optional*, defaults to `"np"`):
                The output format of the generated audio. Choose between `"np"` to return a NumPy `np.ndarray` or
                `"pt"` to return a PyTorch `torch.Tensor` object. Set to `"latent"` to return the latent diffusion
                model (LDM) output.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated audio.
        """
        # 0. Convert audio input length from seconds to spectrogram height
        vocoder_upsample_factor = np.prod(self.vocoder.config.upsample_rates) / self.vocoder.config.sampling_rate

        if audio_length_in_s is None:
            audio_length_in_s = self.unet.config.sample_size * self.vae_scale_factor * vocoder_upsample_factor

        height = int(audio_length_in_s / vocoder_upsample_factor)

        original_waveform_length = int(audio_length_in_s * self.vocoder.config.sampling_rate)
        if height % self.vae_scale_factor != 0:
            height = int(np.ceil(height / self.vae_scale_factor)) * self.vae_scale_factor
            logger.info(
                f"Audio length in seconds {audio_length_in_s} is increased to {height * vocoder_upsample_factor} "
                f"so that it can be handled by the model. It will be cut to {audio_length_in_s} after the "
                f"denoising process."
            )
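
        # Worked example (illustrative config values, not read from a real checkpoint): with a vocoder hop
        # of 160 samples at a 16 kHz sampling rate, vocoder_upsample_factor = 160 / 16000 = 0.01 s per
        # spectrogram frame, so audio_length_in_s=10.0 gives height = 1000 mel frames; with
        # vae_scale_factor=4 this is already divisible and no rounding is applied, otherwise the height is
        # rounded up to the next multiple of the VAE scale factor.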
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
audio_length_in_s,
|
||||
vocoder_upsample_factor,
|
||||
callback_steps,
|
||||
negative_prompt,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
generated_prompt_embeds,
|
||||
negative_generated_prompt_embeds,
|
||||
attention_mask,
|
||||
negative_attention_mask,
|
||||
)
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds, attention_mask, generated_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_waveforms_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
generated_prompt_embeds=generated_prompt_embeds,
|
||||
negative_generated_prompt_embeds=negative_generated_prompt_embeds,
|
||||
attention_mask=attention_mask,
|
||||
negative_attention_mask=negative_attention_mask,
|
||||
max_new_tokens=max_new_tokens,
|
||||
)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.unet.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_waveforms_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Prepare extra step kwargs
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=generated_prompt_embeds,
|
||||
encoder_hidden_states_1=prompt_embeds,
|
||||
encoder_attention_mask_1=attention_mask,
|
||||
).sample
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# 8. Post-processing
|
||||
if not output_type == "latent":
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
mel_spectrogram = self.vae.decode(latents).sample
|
||||
else:
|
||||
return AudioPipelineOutput(audios=latents)
|
||||
|
||||
audio = self.mel_spectrogram_to_waveform(mel_spectrogram)
|
||||
|
||||
audio = audio[:, :original_waveform_length]
|
||||
|
||||
# 9. Automatic scoring
|
||||
if num_waveforms_per_prompt > 1 and prompt is not None:
|
||||
audio = self.score_waveforms(
|
||||
text=prompt,
|
||||
audio=audio,
|
||||
num_waveforms_per_prompt=num_waveforms_per_prompt,
|
||||
device=device,
|
||||
dtype=prompt_embeds.dtype,
|
||||
)
|
||||
|
||||
if output_type == "np":
|
||||
audio = audio.numpy()
|
||||
|
||||
if not return_dict:
|
||||
return (audio,)
|
||||
|
||||
return AudioPipelineOutput(audios=audio)
|
||||
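For reference, a minimal usage sketch of the pipeline defined above (illustrative only, not part of the diff: the Hub repo id is a placeholder, since the integration tests below still load from a local conversion path, and the prompt and settings are arbitrary):

import torch
from scipy.io import wavfile

from diffusers import AudioLDM2Pipeline

# load a converted AudioLDM 2 checkpoint (placeholder repo id) and move it to GPU
pipe = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# generate several candidate waveforms; with num_waveforms_per_prompt > 1 the
# pipeline scores the candidates against the prompt with CLAP and returns the best first
audio = pipe(
    "A hammer hitting a wooden surface",
    negative_prompt="low quality, average quality",
    num_inference_steps=200,
    audio_length_in_s=10.0,
    num_waveforms_per_prompt=3,
).audios

# write the top-ranked waveform at the vocoder's sampling rate
wavfile.write("output.wav", rate=pipe.vocoder.config.sampling_rate, data=audio[0])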
@@ -32,6 +32,51 @@ class AltDiffusionPipeline(metaclass=DummyObject):
        requires_backends(cls, ["torch", "transformers"])


class AudioLDM2Pipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class AudioLDM2ProjectionModel(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class AudioLDM2UNet2DConditionModel(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class AudioLDMPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

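These dummy classes are import-time placeholders: when `torch` or `transformers` is missing, the AudioLDM 2 names can still be imported from `diffusers`, and `requires_backends` raises an informative ImportError only when the object is actually used. A rough sketch of that behaviour (illustrative, not part of the diff):

# with e.g. transformers uninstalled, the import itself still succeeds ...
from diffusers import AudioLDM2Pipeline

try:
    pipe = AudioLDM2Pipeline()  # or AudioLDM2Pipeline.from_pretrained(...)
except ImportError as err:
    # ... and the error message explains which backend(s) need to be installed
    print(err)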
0	tests/pipelines/audioldm2/__init__.py	Normal file (new)
571	tests/pipelines/audioldm2/test_audioldm2.py	Normal file (new)
@@ -0,0 +1,571 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import gc
import unittest

import numpy as np
import torch
from transformers import (
    ClapAudioConfig,
    ClapConfig,
    ClapFeatureExtractor,
    ClapModel,
    ClapTextConfig,
    GPT2Config,
    GPT2Model,
    RobertaTokenizer,
    SpeechT5HifiGan,
    SpeechT5HifiGanConfig,
    T5Config,
    T5EncoderModel,
    T5Tokenizer,
)

from diffusers import (
    AudioLDM2Pipeline,
    AudioLDM2ProjectionModel,
    AudioLDM2UNet2DConditionModel,
    AutoencoderKL,
    DDIMScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
)
from diffusers.utils import is_xformers_available, slow, torch_device
from diffusers.utils.testing_utils import enable_full_determinism

from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS
from ..test_pipelines_common import PipelineTesterMixin


enable_full_determinism()


class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = AudioLDM2Pipeline
    params = TEXT_TO_AUDIO_PARAMS
    batch_params = TEXT_TO_AUDIO_BATCH_PARAMS
    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "num_waveforms_per_prompt",
            "generator",
            "latents",
            "output_type",
            "return_dict",
            "callback",
            "callback_steps",
        ]
    )

    def get_dummy_components(self):
        torch.manual_seed(0)
        unet = AudioLDM2UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=([None, 16, 32], [None, 16, 32]),
        )
        scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        )
        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=1,
            out_channels=1,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
        )
        torch.manual_seed(0)
        text_branch_config = ClapTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=16,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=2,
            num_hidden_layers=2,
            pad_token_id=1,
            vocab_size=1000,
            projection_dim=16,
        )
        audio_branch_config = ClapAudioConfig(
            spec_size=64,
            window_size=4,
            num_mel_bins=64,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            depths=[2, 2],
            num_attention_heads=[2, 2],
            num_hidden_layers=2,
            hidden_size=192,
            projection_dim=16,
            patch_size=2,
            patch_stride=2,
            patch_embed_input_channels=4,
        )
        text_encoder_config = ClapConfig.from_text_audio_configs(
            text_config=text_branch_config, audio_config=audio_branch_config, projection_dim=16
        )
        text_encoder = ClapModel(text_encoder_config)
        tokenizer = RobertaTokenizer.from_pretrained("hf-internal-testing/tiny-random-roberta", model_max_length=77)
        feature_extractor = ClapFeatureExtractor.from_pretrained(
            "hf-internal-testing/tiny-random-ClapModel", hop_length=7900
        )

        torch.manual_seed(0)
        text_encoder_2_config = T5Config(
            vocab_size=32100,
            d_model=32,
            d_ff=37,
            d_kv=8,
            num_heads=2,
            num_layers=2,
        )
        text_encoder_2 = T5EncoderModel(text_encoder_2_config)
        tokenizer_2 = T5Tokenizer.from_pretrained("hf-internal-testing/tiny-random-T5Model", model_max_length=77)

        torch.manual_seed(0)
        language_model_config = GPT2Config(
            n_embd=16,
            n_head=2,
            n_layer=2,
            vocab_size=1000,
            n_ctx=99,
            n_positions=99,
        )
        language_model = GPT2Model(language_model_config)
        language_model.config.max_new_tokens = 8

        torch.manual_seed(0)
        projection_model = AudioLDM2ProjectionModel(text_encoder_dim=16, text_encoder_1_dim=32, langauge_model_dim=16)

        vocoder_config = SpeechT5HifiGanConfig(
            model_in_dim=8,
            sampling_rate=16000,
            upsample_initial_channel=16,
            upsample_rates=[2, 2],
            upsample_kernel_sizes=[4, 4],
            resblock_kernel_sizes=[3, 7],
            resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5]],
            normalize_before=False,
        )

        vocoder = SpeechT5HifiGan(vocoder_config)

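        # The keys below mirror the AudioLDM2Pipeline signature: two text encoder/tokenizer
        # pairs (CLAP and T5), a GPT-2 language model, the projection model that maps both
        # sets of text embeddings into a shared space, and a SpeechT5 HiFi-GAN vocoder,
        # all instantiated with tiny configs so the fast tests stay cheap on CPU.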
        components = {
            "unet": unet,
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "text_encoder_2": text_encoder_2,
            "tokenizer": tokenizer,
            "tokenizer_2": tokenizer_2,
            "feature_extractor": feature_extractor,
            "language_model": language_model,
            "projection_model": projection_model,
            "vocoder": vocoder,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "prompt": "A hammer hitting a wooden surface",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
        }
        return inputs

    def test_audioldm2_ddim(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator

        components = self.get_dummy_components()
        audioldm_pipe = AudioLDM2Pipeline(**components)
        audioldm_pipe = audioldm_pipe.to(torch_device)
        audioldm_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        output = audioldm_pipe(**inputs)
        audio = output.audios[0]

        assert audio.ndim == 1
        assert len(audio) == 256

        audio_slice = audio[:10]
        expected_slice = np.array(
            [0.0025, 0.0018, 0.0018, -0.0023, -0.0026, -0.0020, -0.0026, -0.0021, -0.0027, -0.0020]
        )

        assert np.abs(audio_slice - expected_slice).max() < 1e-4

    def test_audioldm2_prompt_embeds(self):
        components = self.get_dummy_components()
        audioldm_pipe = AudioLDM2Pipeline(**components)
        audioldm_pipe = audioldm_pipe.to(torch_device)
        audioldm_pipe = audioldm_pipe.to(torch_device)
        audioldm_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)
        inputs["prompt"] = 3 * [inputs["prompt"]]

        # forward
        output = audioldm_pipe(**inputs)
        audio_1 = output.audios[0]

        inputs = self.get_dummy_inputs(torch_device)
        prompt = 3 * [inputs.pop("prompt")]

        text_inputs = audioldm_pipe.tokenizer(
            prompt,
            padding="max_length",
            max_length=audioldm_pipe.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_inputs = text_inputs["input_ids"].to(torch_device)

        clap_prompt_embeds = audioldm_pipe.text_encoder.get_text_features(text_inputs)
        clap_prompt_embeds = clap_prompt_embeds[:, None, :]

        text_inputs = audioldm_pipe.tokenizer_2(
            prompt,
            padding="max_length",
            max_length=True,
            truncation=True,
            return_tensors="pt",
        )
        text_inputs = text_inputs["input_ids"].to(torch_device)

        t5_prompt_embeds = audioldm_pipe.text_encoder_2(
            text_inputs,
        )
        t5_prompt_embeds = t5_prompt_embeds[0]

        projection_embeds = audioldm_pipe.projection_model(clap_prompt_embeds, t5_prompt_embeds)[0]
        generated_prompt_embeds = audioldm_pipe.generate_language_model(projection_embeds, max_new_tokens=8)

        inputs["prompt_embeds"] = t5_prompt_embeds
        inputs["generated_prompt_embeds"] = generated_prompt_embeds

        # forward
        output = audioldm_pipe(**inputs)
        audio_2 = output.audios[0]

        assert np.abs(audio_1 - audio_2).max() < 1e-2

    def test_audioldm2_negative_prompt_embeds(self):
        components = self.get_dummy_components()
        audioldm_pipe = AudioLDM2Pipeline(**components)
        audioldm_pipe = audioldm_pipe.to(torch_device)
        audioldm_pipe = audioldm_pipe.to(torch_device)
        audioldm_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)
        negative_prompt = 3 * ["this is a negative prompt"]
        inputs["negative_prompt"] = negative_prompt
        inputs["prompt"] = 3 * [inputs["prompt"]]

        # forward
        output = audioldm_pipe(**inputs)
        audio_1 = output.audios[0]

        inputs = self.get_dummy_inputs(torch_device)
        prompt = 3 * [inputs.pop("prompt")]

        embeds = []
        generated_embeds = []
        for p in [prompt, negative_prompt]:
            text_inputs = audioldm_pipe.tokenizer(
                p,
                padding="max_length",
                max_length=audioldm_pipe.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_inputs = text_inputs["input_ids"].to(torch_device)

            clap_prompt_embeds = audioldm_pipe.text_encoder.get_text_features(text_inputs)
            clap_prompt_embeds = clap_prompt_embeds[:, None, :]

            text_inputs = audioldm_pipe.tokenizer_2(
                prompt,
                padding="max_length",
                max_length=True if len(embeds) == 0 else embeds[0].shape[1],
                truncation=True,
                return_tensors="pt",
            )
            text_inputs = text_inputs["input_ids"].to(torch_device)

            t5_prompt_embeds = audioldm_pipe.text_encoder_2(
                text_inputs,
            )
            t5_prompt_embeds = t5_prompt_embeds[0]

            projection_embeds = audioldm_pipe.projection_model(clap_prompt_embeds, t5_prompt_embeds)[0]
            generated_prompt_embeds = audioldm_pipe.generate_language_model(projection_embeds, max_new_tokens=8)

            embeds.append(t5_prompt_embeds)
            generated_embeds.append(generated_prompt_embeds)

        inputs["prompt_embeds"], inputs["negative_prompt_embeds"] = embeds
        inputs["generated_prompt_embeds"], inputs["negative_generated_prompt_embeds"] = generated_embeds

        # forward
        output = audioldm_pipe(**inputs)
        audio_2 = output.audios[0]

        assert np.abs(audio_1 - audio_2).max() < 1e-2

    def test_audioldm2_negative_prompt(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        components["scheduler"] = PNDMScheduler(skip_prk_steps=True)
        audioldm_pipe = AudioLDM2Pipeline(**components)
        audioldm_pipe = audioldm_pipe.to(device)
        audioldm_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        negative_prompt = "egg cracking"
        output = audioldm_pipe(**inputs, negative_prompt=negative_prompt)
        audio = output.audios[0]

        assert audio.ndim == 1
        assert len(audio) == 256

        audio_slice = audio[:10]
        expected_slice = np.array(
            [0.0025, 0.0018, 0.0018, -0.0023, -0.0026, -0.0020, -0.0026, -0.0021, -0.0027, -0.0020]
        )

        assert np.abs(audio_slice - expected_slice).max() < 1e-4

    def test_audioldm2_num_waveforms_per_prompt(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        components["scheduler"] = PNDMScheduler(skip_prk_steps=True)
        audioldm_pipe = AudioLDM2Pipeline(**components)
        audioldm_pipe = audioldm_pipe.to(device)
        audioldm_pipe.set_progress_bar_config(disable=None)

        prompt = "A hammer hitting a wooden surface"

        # test num_waveforms_per_prompt=1 (default)
        audios = audioldm_pipe(prompt, num_inference_steps=2).audios

        assert audios.shape == (1, 256)

        # test num_waveforms_per_prompt=1 (default) for batch of prompts
        batch_size = 2
        audios = audioldm_pipe([prompt] * batch_size, num_inference_steps=2).audios

        assert audios.shape == (batch_size, 256)

        # test num_waveforms_per_prompt for single prompt
        num_waveforms_per_prompt = 2
        audios = audioldm_pipe(prompt, num_inference_steps=2, num_waveforms_per_prompt=num_waveforms_per_prompt).audios

        assert audios.shape == (num_waveforms_per_prompt, 256)

        # test num_waveforms_per_prompt for batch of prompts
        batch_size = 2
        audios = audioldm_pipe(
            [prompt] * batch_size, num_inference_steps=2, num_waveforms_per_prompt=num_waveforms_per_prompt
        ).audios

        assert audios.shape == (batch_size * num_waveforms_per_prompt, 256)

    def test_audioldm2_audio_length_in_s(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        audioldm_pipe = AudioLDM2Pipeline(**components)
        audioldm_pipe = audioldm_pipe.to(torch_device)
        audioldm_pipe.set_progress_bar_config(disable=None)
        vocoder_sampling_rate = audioldm_pipe.vocoder.config.sampling_rate

        inputs = self.get_dummy_inputs(device)
        output = audioldm_pipe(audio_length_in_s=0.016, **inputs)
        audio = output.audios[0]

        assert audio.ndim == 1
        assert len(audio) / vocoder_sampling_rate == 0.016

        output = audioldm_pipe(audio_length_in_s=0.032, **inputs)
        audio = output.audios[0]

        assert audio.ndim == 1
        assert len(audio) / vocoder_sampling_rate == 0.032

    def test_audioldm2_vocoder_model_in_dim(self):
        components = self.get_dummy_components()
        audioldm_pipe = AudioLDM2Pipeline(**components)
        audioldm_pipe = audioldm_pipe.to(torch_device)
        audioldm_pipe.set_progress_bar_config(disable=None)

        prompt = ["hey"]

        output = audioldm_pipe(prompt, num_inference_steps=1)
        audio_shape = output.audios.shape
        assert audio_shape == (1, 256)

        config = audioldm_pipe.vocoder.config
        config.model_in_dim *= 2
        audioldm_pipe.vocoder = SpeechT5HifiGan(config).to(torch_device)
        output = audioldm_pipe(prompt, num_inference_steps=1)
        audio_shape = output.audios.shape
        # waveform shape is unchanged, we just have 2x the number of mel channels in the spectrogram
        assert audio_shape == (1, 256)

    def test_attention_slicing_forward_pass(self):
        self._test_attention_slicing_forward_pass(test_mean_pixel_difference=False)

    @unittest.skipIf(
        torch_device != "cuda" or not is_xformers_available(),
        reason="XFormers attention is only available with CUDA and `xformers` installed",
    )
    def test_xformers_attention_forwardGenerator_pass(self):
        self._test_xformers_attention_forwardGenerator_pass(test_mean_pixel_difference=False)

    def test_dict_tuple_outputs_equivalent(self):
        # increase tolerance from 1e-4 -> 2e-4 to account for large composite model
        super().test_dict_tuple_outputs_equivalent(expected_max_difference=2e-4)

    def test_inference_batch_single_identical(self):
        # increase tolerance from 1e-4 -> 2e-4 to account for large composite model
        self._test_inference_batch_single_identical(test_mean_pixel_difference=False, expected_max_diff=2e-4)

    def test_save_load_local(self):
        # increase tolerance from 1e-4 -> 2e-4 to account for large composite model
        super().test_save_load_local(expected_max_difference=2e-4)

    def test_save_load_optional_components(self):
        # increase tolerance from 1e-4 -> 2e-4 to account for large composite model
        super().test_save_load_optional_components(expected_max_difference=2e-4)

    def test_to_dtype(self):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.set_progress_bar_config(disable=None)

        # The method component.dtype returns the dtype of the first parameter registered in the model, not the
        # dtype of the entire model. In the case of CLAP, the first parameter is a float64 constant (logit scale)
        model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
        self.assertTrue(model_dtypes["text_encoder"] == torch.float64)

        # Without the logit scale parameters, everything is float32
        model_dtypes.pop("text_encoder")
        self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes.values()))

        # the CLAP sub-models are float32
        model_dtypes["clap_text_branch"] = components["text_encoder"].text_model.dtype
        self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes.values()))

        # Once we send to fp16, all params are in half-precision, including the logit scale
        pipe.to(torch_dtype=torch.float16)
        model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
        self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values()))


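# Note: the fast tests above use only the tiny random components from get_dummy_components(),
# so they can be run standalone on CPU, e.g.
#   pytest tests/pipelines/audioldm2/test_audioldm2.py -k "FastTests"
# The @slow integration tests below load full checkpoints and are skipped unless RUN_SLOW=1 is set.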
@slow
class AudioLDM2PipelineSlowTests(unittest.TestCase):
    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
        generator = torch.Generator(device=generator_device).manual_seed(seed)
        latents = np.random.RandomState(seed).standard_normal((1, 8, 128, 16))
        latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
        inputs = {
            "prompt": "A hammer hitting a wooden surface",
            "latents": latents,
            "generator": generator,
            "num_inference_steps": 3,
            "guidance_scale": 2.5,
        }
        return inputs

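    # The integration tests below pin the initial latents to a seeded NumPy draw of shape
    # (1, 8, 128, 16), so the expected audio slices are reproducible across runs; the tests
    # differ only in the checkpoint loaded and the scheduler used.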
    def test_audioldm2(self):
        audioldm_pipe = AudioLDM2Pipeline.from_pretrained("/home/sanchit/convert-audioldm2/hub-audioldm2")
        audioldm_pipe = audioldm_pipe.to(torch_device)
        audioldm_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_inputs(torch_device)
        inputs["num_inference_steps"] = 25
        audio = audioldm_pipe(**inputs).audios[0]

        assert audio.ndim == 1
        assert len(audio) == 81952

        # check the portion of the generated audio with the largest dynamic range (reduces flakiness)
        audio_slice = audio[17275:17285]
        expected_slice = np.array([0.0791, 0.0666, 0.1158, 0.1227, 0.1171, -0.2880, -0.1940, -0.0283, -0.0126, 0.1127])
        max_diff = np.abs(expected_slice - audio_slice).max()
        assert max_diff < 1e-3

    def test_audioldm2_lms(self):
        audioldm_pipe = AudioLDM2Pipeline.from_pretrained("/home/sanchit/convert-audioldm2/hub-audioldm2")
        audioldm_pipe.scheduler = LMSDiscreteScheduler.from_config(audioldm_pipe.scheduler.config)
        audioldm_pipe = audioldm_pipe.to(torch_device)
        audioldm_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_inputs(torch_device)
        audio = audioldm_pipe(**inputs).audios[0]

        assert audio.ndim == 1
        assert len(audio) == 81952

        # check the portion of the generated audio with the largest dynamic range (reduces flakiness)
        audio_slice = audio[31390:31400]
        expected_slice = np.array(
            [-0.1318, -0.0577, 0.0446, -0.0573, 0.0659, 0.1074, -0.2600, 0.0080, -0.2190, -0.4301]
        )
        max_diff = np.abs(expected_slice - audio_slice).max()
        assert max_diff < 1e-3

    def test_audioldm2_large(self):
        audioldm_pipe = AudioLDM2Pipeline.from_pretrained("/home/sanchit/convert-audioldm2/hub-audioldm2-large")
        audioldm_pipe = audioldm_pipe.to(torch_device)
        audioldm_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_inputs(torch_device)
        audio = audioldm_pipe(**inputs).audios[0]

        assert audio.ndim == 1
        assert len(audio) == 81952

        # check the portion of the generated audio with the largest dynamic range (reduces flakiness)
        audio_slice = audio[8825:8835]
        expected_slice = np.array(
            [-0.1829, -0.1461, 0.0759, -0.1493, -0.1396, 0.5783, 0.3001, -0.3038, -0.0639, -0.2244]
        )
        max_diff = np.abs(expected_slice - audio_slice).max()
        assert max_diff < 1e-3