AudioLDM2 Fixes (#11244)
Commit to https://github.com/huggingface/diffusers.git
@@ -20,7 +20,7 @@ import torch
 from transformers import (
     ClapFeatureExtractor,
     ClapModel,
-    GPT2Model,
+    GPT2LMHeadModel,
     RobertaTokenizer,
     RobertaTokenizerFast,
     SpeechT5HifiGan,
@@ -196,7 +196,7 @@ class AudioLDM2Pipeline(DiffusionPipeline):
         text_encoder: ClapModel,
         text_encoder_2: Union[T5EncoderModel, VitsModel],
         projection_model: AudioLDM2ProjectionModel,
-        language_model: GPT2Model,
+        language_model: GPT2LMHeadModel,
         tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast],
         tokenizer_2: Union[T5Tokenizer, T5TokenizerFast, VitsTokenizer],
         feature_extractor: ClapFeatureExtractor,
@@ -259,7 +259,10 @@ class AudioLDM2Pipeline(DiffusionPipeline):
         )

         device_type = torch_device.type
-        device = torch.device(f"{device_type}:{gpu_id or torch_device.index}")
+        device_str = device_type
+        if gpu_id or torch_device.index:
+            device_str = f"{device_str}:{gpu_id or torch_device.index}"
+        device = torch.device(device_str)

         if self.device.type != "cpu":
             self.to("cpu", silence_dtype_warnings=True)
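The removed line interpolated `gpu_id or torch_device.index` into the device string unconditionally, so when neither is set the result is an invalid string such as `"cpu:None"`. A minimal sketch of the failure mode and of the fixed logic (illustrative only, mirroring the hunk above; not taken from the PR):

```python
import torch

torch_device = torch.device("cpu")  # torch_device.index is None here
gpu_id = None

# Old behavior: None gets interpolated into the device string.
old_style = f"{torch_device.type}:{gpu_id or torch_device.index}"
# old_style == "cpu:None"; torch.device(old_style) raises (invalid device string).

# Fixed behavior: only append an index when one actually exists.
device_str = torch_device.type
if gpu_id or torch_device.index:
    device_str = f"{device_str}:{gpu_id or torch_device.index}"
device = torch.device(device_str)  # device(type='cpu')
```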
@@ -316,9 +319,9 @@ class AudioLDM2Pipeline(DiffusionPipeline):
             model_inputs = prepare_inputs_for_generation(inputs_embeds, **model_kwargs)

             # forward pass to get next hidden states
-            output = self.language_model(**model_inputs, return_dict=True)
+            output = self.language_model(**model_inputs, output_hidden_states=True, return_dict=True)

-            next_hidden_states = output.last_hidden_state
+            next_hidden_states = output.hidden_states[-1]

             # Update the model input
             inputs_embeds = torch.cat([inputs_embeds, next_hidden_states[:, -1:, :]], dim=1)
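This hunk follows from the class swap: `GPT2LMHeadModel` returns a `CausalLMOutputWithCrossAttentions`, which has no `last_hidden_state` field, so the generation loop now requests `output_hidden_states=True` and reads `hidden_states[-1]`. For GPT-2, the last entry of `hidden_states` is the post-layer-norm output that `GPT2Model.last_hidden_state` used to provide. A small sketch of the two access patterns (illustrative shapes only, not from the PR):

```python
import torch
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Model

config = GPT2Config(vocab_size=100, n_embd=32, n_layer=2, n_head=2)
inputs_embeds = torch.randn(1, 4, config.n_embd)

base = GPT2Model(config)
out = base(inputs_embeds=inputs_embeds, return_dict=True)
print(out.last_hidden_state.shape)  # torch.Size([1, 4, 32])

lm = GPT2LMHeadModel(config)
out = lm(inputs_embeds=inputs_embeds, output_hidden_states=True, return_dict=True)
# The LM-head output has no `last_hidden_state`; the final `hidden_states`
# entry (embeddings plus one per block, after the final layer norm) is the
# equivalent tensor.
print(out.hidden_states[-1].shape)  # torch.Size([1, 4, 32])
```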
@@ -26,7 +26,7 @@ from transformers import (
     ClapModel,
     ClapTextConfig,
     GPT2Config,
-    GPT2Model,
+    GPT2LMHeadModel,
     RobertaTokenizer,
     SpeechT5HifiGan,
     SpeechT5HifiGanConfig,
@@ -162,7 +162,7 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
             n_ctx=99,
             n_positions=99,
         )
-        language_model = GPT2Model(language_model_config)
+        language_model = GPT2LMHeadModel(language_model_config)
        language_model.config.max_new_tokens = 8

         torch.manual_seed(0)
@@ -516,6 +516,18 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     def test_encode_prompt_works_in_isolation(self):
         pass

+    @unittest.skip("Not supported yet due to CLAPModel.")
+    def test_sequential_offload_forward_pass_twice(self):
+        pass
+
+    @unittest.skip("Not supported yet, the second forward has mixed devices and `vocoder` is not offloaded.")
+    def test_cpu_offload_forward_pass_twice(self):
+        pass
+
+    @unittest.skip("Not supported yet. `vocoder` is not offloaded.")
+    def test_model_cpu_offload_forward_pass(self):
+        pass
+

 @nightly
 class AudioLDM2PipelineSlowTests(unittest.TestCase):
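The newly skipped tests cover the offload paths touched by the device fix above. For reference, a hedged usage sketch of the model-offload entry point these tests exercise (assumes a CUDA machine and the public `cvssp/audioldm2` checkpoint; prompt and step counts are arbitrary):

```python
import torch
from diffusers import AudioLDM2Pipeline

pipe = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()  # runs the device-string logic fixed above

# Short generation just to exercise the offloaded forward pass.
audio = pipe(
    "a hammer hitting a wooden surface",
    num_inference_steps=10,
    audio_length_in_s=2.0,
).audios[0]
```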