From 3ffdf7f113e442c68d65da5033e31a195f7a1be7 Mon Sep 17 00:00:00 2001
From: leffff
Date: Thu, 16 Oct 2025 10:32:47 +0000
Subject: [PATCH] separated prompt encoding

---
 .../kandinsky5/pipeline_kandinsky.py | 153 +++++++-----------
 1 file changed, 56 insertions(+), 97 deletions(-)

diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py
index b1f7924e9b..2ff0c1d45d 100644
--- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py
+++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py
@@ -359,124 +359,64 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
-        negative_prompt: Optional[Union[str, List[str]]] = None,
-        do_classifier_free_guidance: bool = True,
         num_videos_per_prompt: int = 1,
-        prompt_embeds: Optional[torch.Tensor] = None,
-        negative_prompt_embeds: Optional[torch.Tensor] = None,
         max_sequence_length: int = 512,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
     ):
         r"""
-        Encodes the prompt into text encoder hidden states.
+        Encodes a single prompt (positive or negative) into text encoder hidden states.
 
         This method combines embeddings from both Qwen2.5-VL and CLIP text encoders to create comprehensive text
         representations for video generation.
-
+
         Args:
-            prompt (`str` or `List[str]`, *optional*):
-                prompt to be encoded
-            negative_prompt (`str` or `List[str]`, *optional*):
-                The prompt or prompts not to guide the image generation. If not defined, one has to pass
-                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
-                less than `1`).
-            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
-                Whether to use classifier free guidance or not.
+            prompt (`str` or `List[str]`):
+                Prompt to be encoded.
             num_videos_per_prompt (`int`, *optional*, defaults to 1):
-                Number of videos that should be generated per prompt.
-            prompt_embeds (`torch.Tensor`, *optional*):
-                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
-                provided, text embeddings will be generated from `prompt` input argument.
-            negative_prompt_embeds (`torch.Tensor`, *optional*):
-                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
-                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
-                argument.
+                Number of videos to generate per prompt.
             max_sequence_length (`int`, *optional*, defaults to 512):
                 Maximum sequence length for text encoding.
-            device: (`torch.device`, *optional*):
-                torch device
-            dtype: (`torch.dtype`, *optional*):
-                torch dtype
-
+            device (`torch.device`, *optional*):
+                Torch device.
+            dtype (`torch.dtype`, *optional*):
+                Torch dtype.
+
         Returns:
-            Tuple: Contains prompt embeddings, negative prompt embeddings, and sequence length information
+            Tuple[Dict[str, torch.Tensor], torch.Tensor]:
+                - A dict with keys `"text_embeds"` (from Qwen) and `"pooled_embed"` (from CLIP)
+                - Cumulative sequence lengths (`cu_seqlens`) for Qwen embeddings
         """
         device = device or self._execution_device
+        dtype = dtype or self.text_encoder.dtype
 
-        if prompt is not None and isinstance(prompt, str):
-            batch_size = 1
-            prompt = [prompt]
-        elif prompt is not None and isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, (list, tuple)) else prompt_embeds.shape[0]
+        batch_size = len(prompt)
 
-        if prompt_embeds is None:
-            prompt = [prompt_clean(p) for p in prompt]
-
-            prompt_embeds_qwen, prompt_cu_seqlens = self._encode_prompt_qwen(
-                prompt=prompt,
-                device=device,
-                num_videos_per_prompt=num_videos_per_prompt,
-                max_sequence_length=max_sequence_length,
-                dtype=dtype,
-            )
-            prompt_embeds_clip = self._encode_prompt_clip(
-                prompt=prompt,
-                device=device,
-                num_videos_per_prompt=num_videos_per_prompt,
-                dtype=dtype,
-            )
-        else:
-            prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = prompt_embeds
+        prompt = [prompt_clean(p) for p in prompt]
 
-        if do_classifier_free_guidance and negative_prompt_embeds is None:
-            negative_prompt = negative_prompt or "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
-            negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+        # Encode with Qwen2.5-VL
+        prompt_embeds_qwen, prompt_cu_seqlens = self._encode_prompt_qwen(
+            prompt=prompt,
+            device=device,
+            num_videos_per_prompt=num_videos_per_prompt,
+            max_sequence_length=max_sequence_length,
+            dtype=dtype,
+        )
 
-            if prompt is not None and type(prompt) is not type(negative_prompt):
-                raise TypeError(
-                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
-                    f" {type(prompt)}."
-                )
-            elif batch_size != len(negative_prompt):
-                raise ValueError(
-                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
-                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
-                    " the batch size of `prompt`."
-                )
-
-            negative_prompt = [prompt_clean(p) for p in negative_prompt]
-
-            negative_prompt_embeds_qwen, negative_cu_seqlens = self._encode_prompt_qwen(
-                prompt=negative_prompt,
-                device=device,
-                num_videos_per_prompt=num_videos_per_prompt,
-                max_sequence_length=max_sequence_length,
-                dtype=dtype,
-            )
-            negative_prompt_embeds_clip = self._encode_prompt_clip(
-                prompt=negative_prompt,
-                device=device,
-                num_videos_per_prompt=num_videos_per_prompt,
-                dtype=dtype,
-            )
-        else:
-            negative_prompt_embeds_qwen = None
-            negative_prompt_embeds_clip = None
-            negative_cu_seqlens = None
+        # Encode with CLIP
+        prompt_embeds_clip = self._encode_prompt_clip(
+            prompt=prompt,
+            device=device,
+            num_videos_per_prompt=num_videos_per_prompt,
+            dtype=dtype,
+        )
 
         prompt_embeds_dict = {
             "text_embeds": prompt_embeds_qwen,
             "pooled_embed": prompt_embeds_clip,
         }
-        negative_prompt_embeds_dict = {
-            "text_embeds": negative_prompt_embeds_qwen,
-            "pooled_embed": negative_prompt_embeds_clip,
-        } if do_classifier_free_guidance else None
 
-        return prompt_embeds_dict, negative_prompt_embeds_dict, prompt_cu_seqlens, negative_cu_seqlens
+        return prompt_embeds_dict, prompt_cu_seqlens
 
     def check_inputs(
         self,
@@ -722,24 +662,43 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
+            prompt = [prompt]
         elif prompt is not None and isinstance(prompt, list):
             batch_size = len(prompt)
         else:
             batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, (list, tuple)) else prompt_embeds.shape[0]
 
         # 3. Encode input prompt
-        prompt_embeds_dict, negative_prompt_embeds_dict, prompt_cu_seqlens, negative_cu_seqlens = self.encode_prompt(
+        prompt_embeds_dict, prompt_cu_seqlens = self.encode_prompt(
             prompt=prompt,
-            negative_prompt=negative_prompt,
-            do_classifier_free_guidance=self.do_classifier_free_guidance,
             num_videos_per_prompt=num_videos_per_prompt,
-            prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
             max_sequence_length=max_sequence_length,
             device=device,
             dtype=dtype,
         )
 
+        negative_prompt_embeds_dict = None
+        negative_cu_seqlens = None
+
+        if self.do_classifier_free_guidance:
+            if negative_prompt is None:
+                negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
+
+            if isinstance(negative_prompt, str):
+                negative_prompt = [negative_prompt] * batch_size
+            elif len(negative_prompt) != batch_size:
+                raise ValueError(
+                    f"`negative_prompt` must have the same length as `prompt`. Got {len(negative_prompt)} vs {batch_size}."
+                )
+
+            negative_prompt_embeds_dict, negative_cu_seqlens = self.encode_prompt(
+                prompt=negative_prompt,
+                num_videos_per_prompt=num_videos_per_prompt,
+                max_sequence_length=max_sequence_length,
+                device=device,
+                dtype=dtype,
+            )
+
         # 4. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps = self.scheduler.timesteps
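
Note on the new calling convention: after this patch, classifier-free guidance is driven entirely from `__call__`, and `encode_prompt` encodes exactly one prompt stream per call. A minimal usage sketch of the split path follows; the checkpoint id, prompt text, and dtype are illustrative placeholders (not part of this patch), and `Kandinsky5T2VPipeline` is assumed to be exported from the top-level `diffusers` namespace as other pipelines are:

    import torch
    from diffusers import Kandinsky5T2VPipeline

    pipe = Kandinsky5T2VPipeline.from_pretrained(
        "<kandinsky5-t2v-checkpoint>",  # placeholder checkpoint id
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    # One call per prompt stream: positive and negative are encoded separately.
    prompt_embeds_dict, prompt_cu_seqlens = pipe.encode_prompt(
        prompt=["A cat surfing a wave at sunset"],
        num_videos_per_prompt=1,
        max_sequence_length=512,
    )
    negative_embeds_dict, negative_cu_seqlens = pipe.encode_prompt(
        prompt=["worst quality, low quality"],
        num_videos_per_prompt=1,
        max_sequence_length=512,
    )

    # Each returned dict carries both encoder outputs:
    #   "text_embeds"  -> Qwen2.5-VL hidden states
    #   "pooled_embed" -> CLIP pooled embedding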
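
The `cu_seqlens` tensor returned alongside the Qwen embeddings is named after the cumulative-sequence-lengths convention used by packed variable-length attention (as in FlashAttention's varlen API), where entry i is the offset at which sample i starts along the packed token dimension. Assuming that convention holds here, a self-contained illustration with made-up lengths (the hidden size is a placeholder, not the pipeline's actual value):

    import torch

    seq_lens = torch.tensor([7, 12, 5])  # tokens per prompt in a batch of 3
    cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seq_lens, dim=0)
    # cu_seqlens is now tensor([0, 7, 19, 24], dtype=torch.int32)

    # Sample i occupies packed rows cu_seqlens[i] : cu_seqlens[i + 1].
    packed = torch.randn(int(cu_seqlens[-1]), 3584)  # 3584 = placeholder hidden size
    second_prompt_tokens = packed[cu_seqlens[1] : cu_seqlens[2]]  # 12 rows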