
AudioLDM2 Fixes (#11244)

Author: hlky
Committed by: GitHub
Date: 2025-04-09 09:42:00 +01:00
Parent: fd02aad402
Commit: 9ee3dd3862
2 changed files with 22 additions and 7 deletions

File: src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py

@@ -20,7 +20,7 @@ import torch
 from transformers import (
     ClapFeatureExtractor,
     ClapModel,
-    GPT2Model,
+    GPT2LMHeadModel,
     RobertaTokenizer,
     RobertaTokenizerFast,
     SpeechT5HifiGan,
@@ -196,7 +196,7 @@ class AudioLDM2Pipeline(DiffusionPipeline):
         text_encoder: ClapModel,
         text_encoder_2: Union[T5EncoderModel, VitsModel],
         projection_model: AudioLDM2ProjectionModel,
-        language_model: GPT2Model,
+        language_model: GPT2LMHeadModel,
         tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast],
         tokenizer_2: Union[T5Tokenizer, T5TokenizerFast, VitsTokenizer],
         feature_extractor: ClapFeatureExtractor,
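
Note on the class swap: GPT2LMHeadModel wraps the GPT2Model backbone with a language-modeling head, and its forward returns a CausalLMOutputWithCrossAttentions, which has no last_hidden_state field; that is why the forward pass further down also changes. A minimal sketch of the difference, using an arbitrary tiny config rather than the pipeline's real one:

import torch
from transformers import GPT2Config, GPT2LMHeadModel

# Arbitrary tiny config, for illustration only.
config = GPT2Config(vocab_size=100, n_embd=32, n_layer=2, n_head=2)
model = GPT2LMHeadModel(config).eval()

with torch.no_grad():
    out = model(inputs_embeds=torch.randn(1, 4, 32), return_dict=True)

print(hasattr(out, "last_hidden_state"))  # False: the LM-head output type drops this field
print(out.logits.shape)                   # torch.Size([1, 4, 100]): logits over the vocab instead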
@@ -259,7 +259,10 @@ class AudioLDM2Pipeline(DiffusionPipeline):
         )
 
         device_type = torch_device.type
-        device = torch.device(f"{device_type}:{gpu_id or torch_device.index}")
+        device_str = device_type
+        if gpu_id or torch_device.index:
+            device_str = f"{device_str}:{gpu_id or torch_device.index}"
+        device = torch.device(device_str)
 
         if self.device.type != "cpu":
             self.to("cpu", silence_dtype_warnings=True)
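
Note on the device change: a bare torch.device("cuda") has index None, so when neither gpu_id nor a device index is set, the old f-string produced the literal string "cuda:None", which torch.device rejects. A minimal sketch of the failure mode and the guarded rebuild, with illustrative values:

import torch

device_type = "cuda"  # torch_device.type for a bare torch.device("cuda")
gpu_id = None         # caller passed no explicit GPU id
index = None          # torch.device("cuda").index is None

# Old behavior: f"{device_type}:{gpu_id or index}" -> "cuda:None",
# and torch.device("cuda:None") raises RuntimeError.

# New behavior: append the index only when one actually exists.
device_str = device_type
if gpu_id or index:
    device_str = f"{device_str}:{gpu_id or index}"
print(torch.device(device_str))  # cuda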
@@ -316,9 +319,9 @@ class AudioLDM2Pipeline(DiffusionPipeline):
             model_inputs = prepare_inputs_for_generation(inputs_embeds, **model_kwargs)
 
             # forward pass to get next hidden states
-            output = self.language_model(**model_inputs, return_dict=True)
+            output = self.language_model(**model_inputs, output_hidden_states=True, return_dict=True)
 
-            next_hidden_states = output.last_hidden_state
+            next_hidden_states = output.hidden_states[-1]
 
             # Update the model input
             inputs_embeds = torch.cat([inputs_embeds, next_hidden_states[:, -1:, :]], dim=1)
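
Note on the hidden-state access: with output_hidden_states=True, hidden_states is a tuple holding the embedding output plus one tensor per block, and its last entry equals what the bare GPT2Model backbone (reachable at .transformer) returns as last_hidden_state. A quick sanity check of that equivalence with a tiny throwaway model:

import torch
from transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config(vocab_size=100, n_embd=32, n_layer=2, n_head=2)
model = GPT2LMHeadModel(config).eval()
inputs_embeds = torch.randn(1, 4, 32)

with torch.no_grad():
    out = model(inputs_embeds=inputs_embeds, output_hidden_states=True, return_dict=True)
    # The backbone alone still exposes last_hidden_state directly.
    ref = model.transformer(inputs_embeds=inputs_embeds, return_dict=True).last_hidden_state

assert torch.allclose(out.hidden_states[-1], ref)  # the last tuple entry is the final-layer output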

File: tests/pipelines/audioldm2/test_audioldm2.py

@@ -26,7 +26,7 @@ from transformers import (
     ClapModel,
     ClapTextConfig,
     GPT2Config,
-    GPT2Model,
+    GPT2LMHeadModel,
     RobertaTokenizer,
     SpeechT5HifiGan,
     SpeechT5HifiGanConfig,
@@ -162,7 +162,7 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
             n_ctx=99,
             n_positions=99,
         )
-        language_model = GPT2Model(language_model_config)
+        language_model = GPT2LMHeadModel(language_model_config)
         language_model.config.max_new_tokens = 8
 
         torch.manual_seed(0)
@@ -516,6 +516,18 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     def test_encode_prompt_works_in_isolation(self):
         pass
 
+    @unittest.skip("Not supported yet due to CLAPModel.")
+    def test_sequential_offload_forward_pass_twice(self):
+        pass
+
+    @unittest.skip("Not supported yet, the second forward has mixed devices and `vocoder` is not offloaded.")
+    def test_cpu_offload_forward_pass_twice(self):
+        pass
+
+    @unittest.skip("Not supported yet. `vocoder` is not offloaded.")
+    def test_model_cpu_offload_forward_pass(self):
+        pass
 
 
 @nightly
 class AudioLDM2PipelineSlowTests(unittest.TestCase):