From 149fd53df84c42100062def55d25ca02dc023979 Mon Sep 17 00:00:00 2001
From: leffff
Date: Mon, 13 Oct 2025 22:38:03 +0000
Subject: [PATCH] fix prompt type

---
 .../kandinsky5/pipeline_kandinsky.py          | 227 ++++++++++--------
 1 file changed, 130 insertions(+), 97 deletions(-)

diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py
index a30484c701..407dc127fd 100644
--- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py
+++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py
@@ -33,83 +33,6 @@ from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import KandinskyPipelineOutput
 
-if is_torch_xla_available():
-    import torch_xla.core.xla_model as xm
-
-    XLA_AVAILABLE = True
-else:
-    XLA_AVAILABLE = False
-
-logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
-
-if is_ftfy_available():
-    import ftfy
-
-
-logger = logging.get_logger(__name__)
-
-EXAMPLE_DOC_STRING = """
-    Examples:
-
-    ```python
-    >>> import torch
-    >>> from diffusers import Kandinsky5T2VPipeline, Kandinsky5Transformer3DModel
-    >>> from diffusers.utils import export_to_video
-
-    >>> pipe = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V")
-    >>> pipe = pipe.to("cuda")
-
-    >>> prompt = "A cat and a dog baking a cake together in a kitchen."
-    >>> negative_prompt = "Bright tones, overexposed, static, blurred details"
-
-    >>> output = pipe(
-    ...     prompt=prompt,
-    ...     negative_prompt=negative_prompt,
-    ...     height=512,
-    ...     width=768,
-    ...     num_frames=25,
-    ...     num_inference_steps=50,
-    ...     guidance_scale=5.0,
-    ... ).frames[0]
-    >>> export_to_video(output, "output.mp4", fps=6)
-    ```
-"""
-
-# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import html
-from typing import Any, Callable, Dict, List, Optional, Union
-
-import regex as re
-import torch
-from transformers import Qwen2TokenizerFast, Qwen2VLProcessor, Qwen2_5_VLForConditionalGeneration, AutoProcessor, CLIPTextModel, CLIPTokenizer
-import torchvision
-from torchvision.transforms import ToPILImage
-
-from ...callbacks import MultiPipelineCallbacks, PipelineCallback
-from ...loaders import KandinskyLoraLoaderMixin
-from ...models import AutoencoderKLHunyuanVideo
-from ...models.transformers import Kandinsky5Transformer3DModel
-from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
-from ...utils.torch_utils import randn_tensor
-from ...video_processor import VideoProcessor
-from ..pipeline_utils import DiffusionPipeline
-from .pipeline_output import KandinskyPipelineOutput
-
-
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
 
@@ -137,23 +60,23 @@ EXAMPLE_DOC_STRING = """
     >>> pipe = pipe.to("cuda")
 
     >>> prompt = "A cat and a dog baking a cake together in a kitchen."
-    >>> negative_prompt = "Bright tones, overexposed, static, blurred details"
-
+    >>> negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
     >>> output = pipe(
     ...     prompt=prompt,
     ...     negative_prompt=negative_prompt,
     ...     height=512,
     ...     width=768,
-    ...     num_frames=25,
+    ...     num_frames=121,
     ...     num_inference_steps=50,
     ...     guidance_scale=5.0,
     ... ).frames[0]
-    >>> export_to_video(output, "output.mp4", fps=6)
+    >>> export_to_video(output, "output.mp4", fps=24)
     ```
 """
 
 
 def basic_clean(text):
+    """Clean text using ftfy if available and unescape HTML entities."""
     if is_ftfy_available():
         text = ftfy.fix_text(text)
     text = html.unescape(html.unescape(text))
@@ -161,12 +84,14 @@ def basic_clean(text):
 
 
 def whitespace_clean(text):
+    """Normalize whitespace in text by collapsing runs of whitespace into a single space."""
     text = re.sub(r"\s+", " ", text)
     text = text.strip()
     return text
 
 
 def prompt_clean(text):
+    """Apply both basic cleaning and whitespace normalization to prompts."""
     text = whitespace_clean(basic_clean(text))
     return text
 
@@ -228,6 +153,24 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
     def fast_sta_nabla(
         T: int, H: int, W: int, wT: int = 3, wH: int = 3, wW: int = 3, device="cuda"
     ) -> torch.Tensor:
+        """
+        Create a sparse temporal attention (STA) mask for efficient video generation.
+
+        This method generates a mask that limits attention to nearby frames and spatial positions,
+        reducing computational complexity for video generation.
+
+        Args:
+            T (int): Number of temporal frames
+            H (int): Height in latent space
+            W (int): Width in latent space
+            wT (int): Temporal attention window size
+            wH (int): Height attention window size
+            wW (int): Width attention window size
+            device (str): Device to create tensor on
+
+        Returns:
+            torch.Tensor: Sparse attention mask of shape (T*H*W, T*H*W)
+        """
         l = torch.Tensor([T, H, W]).amax()
         r = torch.arange(0, l, 1, dtype=torch.int16, device=device)
         mat = (r.unsqueeze(1) - r.unsqueeze(0)).abs()
@@ -253,6 +196,19 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
         return sta.reshape(T * H * W, T * H * W)
 
     def get_sparse_params(self, sample, device):
+        """
+        Generate sparse attention parameters for the transformer based on sample dimensions.
+
+        This method computes the sparse attention configuration needed for efficient
+        video processing in the transformer model.
+
+        Args:
+            sample (torch.Tensor): Input sample tensor
+            device (torch.device): Device to place tensors on
+
+        Returns:
+            Dict: Dictionary containing sparse attention parameters
+        """
         assert self.transformer.config.patch_size[0] == 1
         B, T, H, W, _ = sample.shape
         T, H, W = (
@@ -294,12 +250,28 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
         max_sequence_length: int = 256,
         dtype: Optional[torch.dtype] = None,
     ):
+        """
+        Encode prompt using Qwen2.5-VL text encoder.
+
+        This method processes the input prompt through the Qwen2.5-VL model to generate
+        text embeddings suitable for video generation.
+
+        Args:
+            prompt (Union[str, List[str]]): Input prompt or list of prompts
+            device (torch.device): Device to run encoding on
+            num_videos_per_prompt (int): Number of videos to generate per prompt
+            max_sequence_length (int): Maximum sequence length for tokenization
+            dtype (torch.dtype): Data type for embeddings
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: Text embeddings and cumulative sequence lengths
+        """
         device = device or self._execution_device
         dtype = dtype or self.text_encoder.dtype
 
         prompt = [prompt] if isinstance(prompt, str) else prompt
         prompt = [prompt_clean(p) for p in prompt]
 
-        # Kandinsky specific prompt template
+        # Kandinsky specific prompt template for detailed video description
         prompt_template = "\n".join([
             "<|im_start|>system\nYou are a promt engineer. Describe the video in detail.",
             "Describe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.",
             "Pay attention to the order of key actions shown in the scene.<|im_end|>",
             "<|im_start|>user\n{}<|im_end|>",
         ])
-        crop_start = 129
+        crop_start = 129  # Position to start cropping from (system prompt tokens)
 
         full_texts = [prompt_template.format(p) for p in prompt]
 
@@ -347,6 +319,21 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
         num_videos_per_prompt: int = 1,
         dtype: Optional[torch.dtype] = None,
     ):
+        """
+        Encode prompt using CLIP text encoder.
+
+        This method processes the input prompt through the CLIP model to generate
+        pooled embeddings that capture semantic information.
+
+        Args:
+            prompt (Union[str, List[str]]): Input prompt or list of prompts
+            device (torch.device): Device to run encoding on
+            num_videos_per_prompt (int): Number of videos to generate per prompt
+            dtype (torch.dtype): Data type for embeddings
+
+        Returns:
+            torch.Tensor: Pooled text embeddings from CLIP
+        """
         device = device or self._execution_device
         dtype = dtype or self.text_encoder_2.dtype
         prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -386,6 +373,9 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
         r"""
         Encodes the prompt into text encoder hidden states.
 
+        This method combines embeddings from both Qwen2.5-VL and CLIP text encoders
+        to create comprehensive text representations for video generation.
+
         Args:
             prompt (`str` or `List[str]`, *optional*):
                 prompt to be encoded
@@ -410,11 +400,15 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
                 torch device
             dtype: (`torch.dtype`, *optional*):
                 torch dtype
+
+        Returns:
+            Tuple: Contains prompt embeddings, negative prompt embeddings, and sequence length information
         """
         device = device or self._execution_device
 
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
+            prompt = [prompt]
         elif prompt is not None and isinstance(prompt, list):
             batch_size = len(prompt)
         else:
@@ -438,7 +432,7 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
             prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = prompt_embeds
 
         if do_classifier_free_guidance and negative_prompt_embeds is None:
-            negative_prompt = negative_prompt or ""
+            negative_prompt = negative_prompt or "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
             negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
 
             if prompt is not None and type(prompt) is not type(negative_prompt):
@@ -492,6 +486,21 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
         negative_prompt_embeds=None,
         callback_on_step_end_tensor_inputs=None,
     ):
+        """
+        Validate input parameters for the pipeline.
+
+        Args:
+            prompt: Input prompt
+            negative_prompt: Negative prompt for guidance
+            height: Video height
+            width: Video width
+            prompt_embeds: Pre-computed prompt embeddings
+            negative_prompt_embeds: Pre-computed negative prompt embeddings
+            callback_on_step_end_tensor_inputs: Callback tensor inputs
+
+        Raises:
+            ValueError: If inputs are invalid
+        """
         if height % 16 != 0 or width % 16 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
 
@@ -535,6 +544,26 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        """
+        Prepare initial latent variables for video generation.
+
+        This method creates random noise latents or uses provided latents as the starting point
+        for the denoising process.
+
+        Args:
+            batch_size (int): Number of videos to generate
+            num_channels_latents (int): Number of channels in latent space
+            height (int): Height of generated video
+            width (int): Width of generated video
+            num_frames (int): Number of frames in video
+            dtype (torch.dtype): Data type for latents
+            device (torch.device): Device to create latents on
+            generator (torch.Generator): Random number generator
+            latents (torch.Tensor): Pre-existing latents to use
+
+        Returns:
+            torch.Tensor: Prepared latent tensor
+        """
         if latents is not None:
             return latents.to(device=device, dtype=dtype)
 
@@ -568,18 +597,22 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
 
     @property
     def guidance_scale(self):
+        """Get the current guidance scale value."""
         return self._guidance_scale
 
     @property
     def do_classifier_free_guidance(self):
+        """Check if classifier-free guidance is enabled."""
         return self._guidance_scale > 1.0
 
     @property
     def num_timesteps(self):
+        """Get the number of denoising timesteps."""
         return self._num_timesteps
 
     @property
     def interrupt(self):
+        """Check if generation has been interrupted."""
         return self._interrupt
 
     @torch.no_grad()
@@ -590,10 +623,10 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
         negative_prompt: Optional[Union[str, List[str]]] = None,
         height: int = 512,
         width: int = 768,
-        num_frames: int = 25,
+        num_frames: int = 121,
         num_inference_steps: int = 50,
         guidance_scale: float = 5.0,
-        scheduler_scale: float = 10.0,
+        scheduler_scale: float = 5.0,
         num_videos_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
@@ -715,7 +748,7 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
         timesteps = self.scheduler.timesteps
 
         # 5. Prepare latent variables
-        num_channels_latents = 16
+        num_channels_latents = self.transformer.config.in_visual_dim
         latents = self.prepare_latents(
             batch_size * num_videos_per_prompt,
             num_channels_latents,
@@ -728,7 +761,7 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
             latents,
         )
 
-        # 6. Prepare rope positions
+        # 6. Prepare rope positions for positional encoding
        num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
         visual_rope_pos = [
             torch.arange(num_latent_frames, device=device),
@@ -744,7 +777,7 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
             else None
         )
 
-        # 7. Sparse Params
+        # 7. Sparse Params for efficient attention
         sparse_params = self.get_sparse_params(latents, device)
 
         # 8. Denoising loop
@@ -788,9 +821,9 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
                         pred_velocity - uncond_pred_velocity
                     )
 
-                # Compute previous sample
-                latents[:, :, :, :, :16] = self.scheduler.step(
-                    pred_velocity, t, latents[:, :, :, :, :16], return_dict=False
+                # Compute previous sample using the scheduler
+                latents[:, :, :, :, :num_channels_latents] = self.scheduler.step(
+                    pred_velocity, t, latents[:, :, :, :, :num_channels_latents], return_dict=False
                 )[0]
 
                 if callback_on_step_end is not None:
@@ -809,8 +842,8 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
                 if XLA_AVAILABLE:
                     xm.mark_step()
 
-        # 8. Post-processing
-        latents = latents[:, :, :, :, :16]
+        # 8. Post-processing - extract main latents
+        latents = latents[:, :, :, :, :num_channels_latents]
 
         # 9. Decode latents to video
         if output_type != "latent":
@@ -822,18 +855,18 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
                 (num_frames - 1) // self.vae_scale_factor_temporal + 1,
                 height // self.vae_scale_factor_spatial,
                 width // self.vae_scale_factor_spatial,
-                16,
+                num_channels_latents,
             )
             video = video.permute(0, 1, 5, 2, 3, 4)  # [batch, num_videos, channels, frames, height, width]
             video = video.reshape(
                 batch_size * num_videos_per_prompt,
-                16,
+                num_channels_latents,
                 (num_frames - 1) // self.vae_scale_factor_temporal + 1,
                 height // self.vae_scale_factor_spatial,
                 width // self.vae_scale_factor_spatial
             )
 
-            # Normalize and decode
+            # Normalize and decode through VAE
             video = video / self.vae.config.scaling_factor
             video = self.vae.decode(video).sample
             video = self.video_processor.postprocess_video(video, output_type=output_type)
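
Aside for reviewers (not part of the applied diff): the docstring this patch adds to `fast_sta_nabla` describes a windowed mask over nearby frames and spatial positions, but most of the function body falls outside the hunks shown here. Below is a minimal standalone sketch of the masking rule as that docstring describes it, assuming a query at (t, h, w) may attend only to keys within `wT // 2`, `wH // 2`, and `wW // 2` steps along each axis; the helper name `sta_mask_reference` is hypothetical and does not exist in the pipeline.

```python
import torch


def sta_mask_reference(T: int, H: int, W: int, wT: int = 3, wH: int = 3, wW: int = 3) -> torch.Tensor:
    """Dense boolean reference for a windowed (T, H, W) attention mask.

    A query at (t, h, w) attends to a key at (t', h', w') only when
    |t - t'| <= wT // 2, |h - h'| <= wH // 2, and |w - w'| <= wW // 2.
    """
    t, h, w = torch.arange(T), torch.arange(H), torch.arange(W)
    # Per-axis banded masks: True where the absolute offset fits in the window.
    mt = (t[:, None] - t[None, :]).abs() <= wT // 2  # (T, T)
    mh = (h[:, None] - h[None, :]).abs() <= wH // 2  # (H, H)
    mw = (w[:, None] - w[None, :]).abs() <= wW // 2  # (W, W)
    # Combine the three bands into a (T, H, W, T, H, W) mask, then flatten both
    # (t, h, w) index triples to match the (T*H*W, T*H*W) shape documented above.
    mask = (
        mt[:, None, None, :, None, None]
        & mh[None, :, None, None, :, None]
        & mw[None, None, :, None, None, :]
    )
    return mask.reshape(T * H * W, T * H * W)


# Example: 4 latent frames on an 8x8 latent grid with the default 3-wide windows.
mask = sta_mask_reference(4, 8, 8)
print(mask.shape, mask.float().mean())  # torch.Size([256, 256]) and the kept-attention fraction
```

Because each axis keeps only a small neighborhood, the fraction of allowed query-key pairs shrinks roughly with the product of the per-axis window ratios, which is the source of the efficiency gain the added docstring mentions.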