From cbb10b8dcae1ea9588fbb31aaadb7c60d3bba27f Mon Sep 17 00:00:00 2001
From: Daniel Gu
Date: Tue, 23 Dec 2025 07:01:17 +0100
Subject: [PATCH] Support num_videos_per_prompt for prompt embeddings

---
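Reviewer note (commentary below the "---" line is dropped by git am): the
duplication hunk relies on the repeat/view idiom that the in-code comment
calls the "mps friendly method". A minimal standalone sketch with made-up
shapes, checking that the idiom matches repeat_interleave along the batch
dimension:

    import torch

    batch_size, seq_len, dim = 2, 16, 8
    num_videos_per_prompt = 3
    prompt_embeds = torch.randn(batch_size, seq_len, dim)

    # tile along the sequence dim, then fold the copies into the batch dim
    out = prompt_embeds.repeat(1, num_videos_per_prompt, 1)          # (2, 48, 8)
    out = out.view(batch_size * num_videos_per_prompt, seq_len, -1)  # (6, 16, 8)

    # identical to interleaved duplication along the batch dimension
    assert torch.equal(out, prompt_embeds.repeat_interleave(num_videos_per_prompt, dim=0))
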
 src/diffusers/pipelines/ltx2/pipeline_ltx2.py | 38 +++++++++++++++++--
 1 file changed, 35 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py
index 250ff7284f..af9b0096fd 100644
--- a/src/diffusers/pipelines/ltx2/pipeline_ltx2.py
+++ b/src/diffusers/pipelines/ltx2/pipeline_ltx2.py
@@ -255,10 +255,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
     def _get_gemma_prompt_embeds(
         self,
         prompt: Union[str, List[str]],
-        device: torch.device,
-        dtype: torch.dtype,
+        num_videos_per_prompt: int = 1,
         max_sequence_length: int = 1024,
         scale_factor: int = 8,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
     ):
         r"""
         Encodes the prompt into text encoder hidden states.
@@ -272,7 +273,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
                 torch dtype to cast the prompt embeds to
             max_sequence_length (`int`, defaults to 1024):
                 Maximum sequence length to use for the prompt.
         """
+        device = device or self._execution_device
+        dtype = dtype or self.text_encoder.base_text_encoder.dtype
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+        batch_size = len(prompt)
 
         if getattr(self, "tokenizer", None) is not None:
             # Gemma expects left padding for chat-style prompts
@@ -301,6 +306,18 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
         prompt_embeds = prompt_embeds.to(dtype=dtype)
         audio_prompt_embeds = audio_prompt_embeds.to(dtype=dtype)
 
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        _, seq_len, _ = prompt_embeds.shape
+        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+        _, audio_seq_len, _ = audio_prompt_embeds.shape
+        audio_prompt_embeds = audio_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+        audio_prompt_embeds = audio_prompt_embeds.view(batch_size * num_videos_per_prompt, audio_seq_len, -1)
+
+        prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
+        prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
+
         return prompt_embeds, audio_prompt_embeds, prompt_attention_mask
 
     def encode_prompt(
@@ -310,10 +327,13 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
         do_classifier_free_guidance: bool = True,
         num_videos_per_prompt: int = 1,
         prompt_embeds: Optional[torch.Tensor] = None,
+        audio_prompt_embeds: Optional[torch.Tensor] = None,
         negative_prompt_embeds: Optional[torch.Tensor] = None,
+        negative_audio_prompt_embeds: Optional[torch.Tensor] = None,
         prompt_attention_mask: Optional[torch.Tensor] = None,
         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
-        max_sequence_length: int = 128,
+        max_sequence_length: int = 1024,
+        scale_factor: int = 8,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
     ):
@@ -356,6 +376,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
                 prompt=prompt,
                 num_videos_per_prompt=num_videos_per_prompt,
                 max_sequence_length=max_sequence_length,
+                scale_factor=scale_factor,
                 device=device,
                 dtype=dtype,
             )
@@ -380,6 +401,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
                 prompt=negative_prompt,
                 num_videos_per_prompt=num_videos_per_prompt,
                 max_sequence_length=max_sequence_length,
+                scale_factor=scale_factor,
                 device=device,
                 dtype=dtype,
             )
@@ -650,8 +672,10 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
         latents: Optional[torch.Tensor] = None,
         audio_latents: Optional[torch.Tensor] = None,
         prompt_embeds: Optional[torch.Tensor] = None,
+        audio_prompt_embeds: Optional[torch.Tensor] = None,
         prompt_attention_mask: Optional[torch.Tensor] = None,
         negative_prompt_embeds: Optional[torch.Tensor] = None,
+        negative_audio_prompt_embeds: Optional[torch.Tensor] = None,
         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
         decode_timestep: Union[float, List[float]] = 0.0,
         decode_noise_scale: Optional[Union[float, List[float]]] = None,
@@ -712,11 +736,17 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
             prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                 not provided, text embeddings will be generated from `prompt` input argument.
+            audio_prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated text embeddings for audio processing. Can be used to easily tweak text inputs, *e.g.*
+                prompt weighting. If not provided, audio_prompt_embeds will be generated from `prompt` input argument.
             prompt_attention_mask (`torch.Tensor`, *optional*):
                 Pre-generated attention mask for text embeddings.
             negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
                 provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
+            negative_audio_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings for audio processing. If not provided,
+                negative_audio_prompt_embeds will be generated from `negative_prompt` input argument.
             negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
                 Pre-generated attention mask for negative text embeddings.
             decode_timestep (`float`, defaults to `0.0`):
@@ -797,7 +827,9 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMix
             do_classifier_free_guidance=self.do_classifier_free_guidance,
             num_videos_per_prompt=num_videos_per_prompt,
             prompt_embeds=prompt_embeds,
+            audio_prompt_embeds=audio_prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
+            negative_audio_prompt_embeds=negative_audio_prompt_embeds,
             prompt_attention_mask=prompt_attention_mask,
             negative_prompt_attention_mask=negative_prompt_attention_mask,
             max_sequence_length=max_sequence_length,
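
Usage note: with this change the prompt is encoded once and fanned out to
num_videos_per_prompt copies inside _get_gemma_prompt_embeds. A hypothetical
sketch of the effect; the top-level LTX2Pipeline export and the checkpoint id
are assumptions, and _get_gemma_prompt_embeds is a private helper called here
only to show the new argument:

    import torch
    from diffusers import LTX2Pipeline  # assumed top-level export

    # placeholder checkpoint id, not taken from the patch
    pipe = LTX2Pipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    # device/dtype may now be omitted; they fall back to the pipeline's
    # execution device and the text encoder dtype
    prompt_embeds, audio_prompt_embeds, prompt_attention_mask = pipe._get_gemma_prompt_embeds(
        prompt="A cat playing a tiny piano",
        num_videos_per_prompt=2,
    )
    print(prompt_embeds.shape[0])  # 1 prompt * 2 videos per prompt = 2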