Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
@@ -40,7 +40,6 @@ from ...utils import (
     logging,
     replace_example_docstring,
 )
-from ...utils.import_utils import is_transformers_version
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
 from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
@@ -313,19 +312,8 @@ class AudioLDM2Pipeline(DiffusionPipeline):
             `inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
                 The sequence of generated hidden-states.
         """
-        cache_position_kwargs = {}
-        if is_transformers_version("<", "4.52.0.dev0"):
-            cache_position_kwargs["input_ids"] = inputs_embeds
-            cache_position_kwargs["model_kwargs"] = model_kwargs
-        else:
-            cache_position_kwargs["seq_length"] = inputs_embeds.shape[0]
-            cache_position_kwargs["device"] = (
-                self.language_model.device if getattr(self, "language_model", None) is not None else self.device
-            )
-            cache_position_kwargs["model_kwargs"] = model_kwargs
         max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens
-        model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs)
-
+        model_kwargs = self.language_model._get_initial_cache_position(inputs_embeds, model_kwargs)
         for _ in range(max_new_tokens):
             # prepare model inputs
             model_inputs = prepare_inputs_for_generation(inputs_embeds, **model_kwargs)
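
What the second hunk drops is a compatibility shim: transformers' private GenerationMixin._get_initial_cache_position helper changed its signature at 4.52.0.dev0 from (input_ids, model_kwargs) to (seq_length, device, model_kwargs), and the removed block built the matching kwargs for whichever version was installed. Below is a minimal standalone sketch of that version-gating pattern; only is_transformers_version is real diffusers API, while build_cache_position_kwargs and its arguments are illustrative, not pipeline code.

from diffusers.utils.import_utils import is_transformers_version


def build_cache_position_kwargs(inputs_embeds, model_kwargs, device):
    # Assemble kwargs for transformers' private _get_initial_cache_position,
    # whose signature the removed block had to straddle across 4.52.0.dev0.
    # (Illustrative helper, not the pipeline's actual code.)
    kwargs = {"model_kwargs": model_kwargs}
    if is_transformers_version("<", "4.52.0.dev0"):
        # Older transformers took the input tensor itself.
        kwargs["input_ids"] = inputs_embeds
    else:
        # Newer transformers take a sequence length and a device instead.
        kwargs["seq_length"] = inputs_embeds.shape[0]  # as in the removed code
        kwargs["device"] = device
    return kwargs

With the shim gone, the call site collapses to the single added line, using the two-argument form that the removed branch reserved for transformers < 4.52.0.dev0.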
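The hunk's trailing context lines sit at the top of the pipeline's token-by-token generation loop. The loop body past prepare_inputs_for_generation is outside this diff; the following is a rough, self-contained sketch of the general shape of such an embeddings-level loop, where the toy model and all names are assumptions rather than the AudioLDM2 implementation.

import torch
import torch.nn as nn


def generate_hidden_states(model, inputs_embeds, max_new_tokens):
    # Toy embeddings-level greedy loop: feed the model its own last hidden
    # state each step. The real pipeline instead builds inputs via
    # prepare_inputs_for_generation(inputs_embeds, **model_kwargs) and
    # threads cache state through model_kwargs (assumption for this sketch).
    for _ in range(max_new_tokens):
        output = model(inputs_embeds)  # (batch, seq, hidden)
        # Append the newest hidden state and continue autoregressively.
        inputs_embeds = torch.cat([inputs_embeds, output[:, -1:, :]], dim=1)
    # Return only the newly generated hidden states.
    return inputs_embeds[:, -max_new_tokens:, :]


if __name__ == "__main__":
    toy = nn.Linear(16, 16)  # stand-in for the language model
    embeds = torch.randn(2, 4, 16)
    print(generate_hidden_states(toy, embeds, max_new_tokens=3).shape)  # (2, 3, 16)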