
commit cbb10b8dca
parent 6e6ce20595
Author: Daniel Gu
Date:   2025-12-23 07:01:17 +01:00

    Support num_videos_per_prompt for prompt embeddings


@@ -255,10 +255,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
     def _get_gemma_prompt_embeds(
         self,
         prompt: Union[str, List[str]],
-        device: torch.device,
-        dtype: torch.dtype,
+        num_videos_per_prompt: int = 1,
         max_sequence_length: int = 1024,
         scale_factor: int = 8,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -272,7 +273,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
                 torch dtype to cast the prompt embeds to
             max_sequence_length (`int`, defaults to 1024): Maximum sequence length to use for the prompt.
         """
+        device = device or self._execution_device
+        dtype = dtype or self.text_encoder.base_text_encoder.dtype
+
         prompt = [prompt] if isinstance(prompt, str) else prompt
         batch_size = len(prompt)
+
         if getattr(self, "tokenizer", None) is not None:
             # Gemma expects left padding for chat-style prompts
@@ -301,6 +306,18 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
         prompt_embeds = prompt_embeds.to(dtype=dtype)
         audio_prompt_embeds = audio_prompt_embeds.to(dtype=dtype)
 
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        _, seq_len, _ = prompt_embeds.shape
+        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+        _, audio_seq_len, _ = audio_prompt_embeds.shape
+        audio_prompt_embeds = audio_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+        audio_prompt_embeds = audio_prompt_embeds.view(batch_size * num_videos_per_prompt, audio_seq_len, -1)
+
+        prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
+        prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
+
         return prompt_embeds, audio_prompt_embeds, prompt_attention_mask
 
     def encode_prompt(
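
Note on the duplication pattern in the hunk above: repeat along dim=1 followed by view is an MPS-friendly equivalent of torch.repeat_interleave along the batch dimension. A minimal standalone sketch, with shapes assumed purely for illustration:

import torch

batch_size, seq_len, dim = 2, 16, 64
num_videos_per_prompt = 3
embeds = torch.randn(batch_size, seq_len, dim)

# Tile the sequence dimension, then fold the copies back into the batch
# dimension; all copies of sample 0 come first, then all copies of sample 1.
out = embeds.repeat(1, num_videos_per_prompt, 1)
out = out.view(batch_size * num_videos_per_prompt, seq_len, -1)

assert torch.equal(out, embeds.repeat_interleave(num_videos_per_prompt, dim=0))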
@@ -310,10 +327,13 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
         do_classifier_free_guidance: bool = True,
         num_videos_per_prompt: int = 1,
         prompt_embeds: Optional[torch.Tensor] = None,
+        audio_prompt_embeds: Optional[torch.Tensor] = None,
         negative_prompt_embeds: Optional[torch.Tensor] = None,
+        negative_audio_prompt_embeds: Optional[torch.Tensor] = None,
         prompt_attention_mask: Optional[torch.Tensor] = None,
         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
-        max_sequence_length: int = 128,
+        max_sequence_length: int = 1024,
+        scale_factor: int = 8,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
     ):
@@ -356,6 +376,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
                 prompt=prompt,
                 num_videos_per_prompt=num_videos_per_prompt,
                 max_sequence_length=max_sequence_length,
+                scale_factor=scale_factor,
                 device=device,
                 dtype=dtype,
             )
@@ -380,6 +401,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
                 prompt=negative_prompt,
                 num_videos_per_prompt=num_videos_per_prompt,
                 max_sequence_length=max_sequence_length,
+                scale_factor=scale_factor,
                 device=device,
                 dtype=dtype,
             )
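
For context (the concatenation itself is outside this diff): diffusers pipelines typically stack the negative and positive embeddings after both calls above return, so one batched forward pass serves both classifier-free-guidance branches. A sketch, assuming both branches produce matching shapes of (batch_size * num_videos_per_prompt, seq_len, dim):

import torch

prompt_embeds = torch.randn(2, 1024, 64)
negative_prompt_embeds = torch.zeros_like(prompt_embeds)

# Negative embeddings first, positive second, doubling the batch dimension.
cfg_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
assert cfg_embeds.shape[0] == 2 * prompt_embeds.shape[0]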
@@ -650,8 +672,10 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
         latents: Optional[torch.Tensor] = None,
         audio_latents: Optional[torch.Tensor] = None,
         prompt_embeds: Optional[torch.Tensor] = None,
+        audio_prompt_embeds: Optional[torch.Tensor] = None,
         prompt_attention_mask: Optional[torch.Tensor] = None,
         negative_prompt_embeds: Optional[torch.Tensor] = None,
+        negative_audio_prompt_embeds: Optional[torch.Tensor] = None,
         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
         decode_timestep: Union[float, List[float]] = 0.0,
         decode_noise_scale: Optional[Union[float, List[float]]] = None,
@@ -712,11 +736,17 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
             prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
+            audio_prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated text embeddings for the audio branch. Can be used to easily tweak text inputs, *e.g.*
+                prompt weighting. If not provided, audio text embeddings will be generated from `prompt` input argument.
             prompt_attention_mask (`torch.Tensor`, *optional*):
                 Pre-generated attention mask for text embeddings.
             negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
                 provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
+            negative_audio_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings for the audio branch. If not provided,
+                negative_audio_prompt_embeds will be generated from `negative_prompt` input argument.
             negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
                 Pre-generated attention mask for negative text embeddings.
             decode_timestep (`float`, defaults to `0.0`):
@@ -797,7 +827,9 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
             do_classifier_free_guidance=self.do_classifier_free_guidance,
             num_videos_per_prompt=num_videos_per_prompt,
             prompt_embeds=prompt_embeds,
+            audio_prompt_embeds=audio_prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
+            negative_audio_prompt_embeds=negative_audio_prompt_embeds,
             prompt_attention_mask=prompt_attention_mask,
             negative_prompt_attention_mask=negative_prompt_attention_mask,
             max_sequence_length=max_sequence_length,
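
End to end, the change lets callers request several videos per prompt while the text-encoder pass runs only once per prompt. A hedged usage sketch; the checkpoint id is a placeholder and the output attributes are not shown in this diff:

import torch
from diffusers import LTX2Pipeline

pipe = LTX2Pipeline.from_pretrained("<ltx2-checkpoint>", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Embeddings are duplicated internally for num_videos_per_prompt, so a single
# encode yields two video variations of the same prompt.
out = pipe(
    prompt="A red fox running through fresh snow",
    num_videos_per_prompt=2,
    max_sequence_length=1024,
)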