1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

fix prompt type

This commit is contained in:
leffff
2025-10-13 22:38:03 +00:00
parent 43bd1e81d2
commit 149fd53df8

View File

@@ -33,83 +33,6 @@ from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import KandinskyPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
if is_ftfy_available():
import ftfy
logger = logging.get_logger(__name__)
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from diffusers import Kandinsky5T2VPipeline, Kandinsky5Transformer3DModel
>>> from diffusers.utils import export_to_video
>>> pipe = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V")
>>> pipe = pipe.to("cuda")
>>> prompt = "A cat and a dog baking a cake together in a kitchen."
>>> negative_prompt = "Bright tones, overexposed, static, blurred details"
>>> output = pipe(
... prompt=prompt,
... negative_prompt=negative_prompt,
... height=512,
... width=768,
... num_frames=25,
... num_inference_steps=50,
... guidance_scale=5.0,
... ).frames[0]
>>> export_to_video(output, "output.mp4", fps=6)
```
"""
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import html
from typing import Any, Callable, Dict, List, Optional, Union
import regex as re
import torch
from transformers import Qwen2TokenizerFast, Qwen2VLProcessor, Qwen2_5_VLForConditionalGeneration, AutoProcessor, CLIPTextModel, CLIPTokenizer
import torchvision
from torchvision.transforms import ToPILImage
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...loaders import KandinskyLoraLoaderMixin
from ...models import AutoencoderKLHunyuanVideo
from ...models.transformers import Kandinsky5Transformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import KandinskyPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
@@ -137,23 +60,23 @@ EXAMPLE_DOC_STRING = """
>>> pipe = pipe.to("cuda")
>>> prompt = "A cat and a dog baking a cake together in a kitchen."
>>> negative_prompt = "Bright tones, overexposed, static, blurred details"
>>> negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
>>> output = pipe(
... prompt=prompt,
... negative_prompt=negative_prompt,
... height=512,
... width=768,
... num_frames=25,
... num_frames=121,
... num_inference_steps=50,
... guidance_scale=5.0,
... ).frames[0]
>>> export_to_video(output, "output.mp4", fps=6)
>>> export_to_video(output, "output.mp4", fps=24)
```
"""
def basic_clean(text):
"""Clean text using ftfy if available and unescape HTML entities."""
if is_ftfy_available():
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
@@ -161,12 +84,14 @@ def basic_clean(text):
def whitespace_clean(text):
"""Normalize whitespace in text by replacing multiple spaces with single space."""
text = re.sub(r"\s+", " ", text)
text = text.strip()
return text
def prompt_clean(text):
"""Apply both basic cleaning and whitespace normalization to prompts."""
text = whitespace_clean(basic_clean(text))
return text
@@ -228,6 +153,24 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
def fast_sta_nabla(
T: int, H: int, W: int, wT: int = 3, wH: int = 3, wW: int = 3, device="cuda"
) -> torch.Tensor:
"""
Create a sparse temporal attention (STA) mask for efficient video generation.
This method generates a mask that limits attention to nearby frames and spatial positions,
reducing computational complexity for video generation.
Args:
T (int): Number of temporal frames
H (int): Height in latent space
W (int): Width in latent space
wT (int): Temporal attention window size
wH (int): Height attention window size
wW (int): Width attention window size
device (str): Device to create tensor on
Returns:
torch.Tensor: Sparse attention mask of shape (T*H*W, T*H*W)
"""
l = torch.Tensor([T, H, W]).amax()
r = torch.arange(0, l, 1, dtype=torch.int16, device=device)
mat = (r.unsqueeze(1) - r.unsqueeze(0)).abs()
@@ -253,6 +196,19 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
return sta.reshape(T * H * W, T * H * W)
def get_sparse_params(self, sample, device):
"""
Generate sparse attention parameters for the transformer based on sample dimensions.
This method computes the sparse attention configuration needed for efficient
video processing in the transformer model.
Args:
sample (torch.Tensor): Input sample tensor
device (torch.device): Device to place tensors on
Returns:
Dict: Dictionary containing sparse attention parameters
"""
assert self.transformer.config.patch_size[0] == 1
B, T, H, W, _ = sample.shape
T, H, W = (
@@ -294,12 +250,28 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
max_sequence_length: int = 256,
dtype: Optional[torch.dtype] = None,
):
"""
Encode prompt using Qwen2.5-VL text encoder.
This method processes the input prompt through the Qwen2.5-VL model to generate
text embeddings suitable for video generation.
Args:
prompt (Union[str, List[str]]): Input prompt or list of prompts
device (torch.device): Device to run encoding on
num_videos_per_prompt (int): Number of videos to generate per prompt
max_sequence_length (int): Maximum sequence length for tokenization
dtype (torch.dtype): Data type for embeddings
Returns:
Tuple[torch.Tensor, torch.Tensor]: Text embeddings and cumulative sequence lengths
"""
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
prompt = [prompt_clean(p) for p in prompt]
# Kandinsky specific prompt template
# Kandinsky specific prompt template for detailed video description
prompt_template = "\n".join([
"<|im_start|>system\nYou are a promt engineer. Describe the video in detail.",
"Describe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.",
@@ -310,7 +282,7 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
"Pay attention to the order of key actions shown in the scene.<|im_end|>",
"<|im_start|>user\n{}<|im_end|>",
])
crop_start = 129
crop_start = 129 # Position to start cropping from (system prompt tokens)
full_texts = [prompt_template.format(p) for p in prompt]
@@ -347,6 +319,21 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
num_videos_per_prompt: int = 1,
dtype: Optional[torch.dtype] = None,
):
"""
Encode prompt using CLIP text encoder.
This method processes the input prompt through the CLIP model to generate
pooled embeddings that capture semantic information.
Args:
prompt (Union[str, List[str]]): Input prompt or list of prompts
device (torch.device): Device to run encoding on
num_videos_per_prompt (int): Number of videos to generate per prompt
dtype (torch.dtype): Data type for embeddings
Returns:
torch.Tensor: Pooled text embeddings from CLIP
"""
device = device or self._execution_device
dtype = dtype or self.text_encoder_2.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -386,6 +373,9 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
r"""
Encodes the prompt into text encoder hidden states.
This method combines embeddings from both Qwen2.5-VL and CLIP text encoders
to create comprehensive text representations for video generation.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
@@ -410,11 +400,15 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
torch device
dtype: (`torch.dtype`, *optional*):
torch dtype
Returns:
Tuple: Contains prompt embeddings, negative prompt embeddings, and sequence length information
"""
device = device or self._execution_device
if prompt is not None and isinstance(prompt, str):
batch_size = 1
prompt = [prompt]
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
@@ -438,7 +432,7 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = prompt_embeds
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt = negative_prompt or "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
@@ -492,6 +486,21 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
):
"""
Validate input parameters for the pipeline.
Args:
prompt: Input prompt
negative_prompt: Negative prompt for guidance
height: Video height
width: Video width
prompt_embeds: Pre-computed prompt embeddings
negative_prompt_embeds: Pre-computed negative prompt embeddings
callback_on_step_end_tensor_inputs: Callback tensor inputs
Raises:
ValueError: If inputs are invalid
"""
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
@@ -535,6 +544,26 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Prepare initial latent variables for video generation.
This method creates random noise latents or uses provided latents as starting point
for the denoising process.
Args:
batch_size (int): Number of videos to generate
num_channels_latents (int): Number of channels in latent space
height (int): Height of generated video
width (int): Width of generated video
num_frames (int): Number of frames in video
dtype (torch.dtype): Data type for latents
device (torch.device): Device to create latents on
generator (torch.Generator): Random number generator
latents (torch.Tensor): Pre-existing latents to use
Returns:
torch.Tensor: Prepared latent tensor
"""
if latents is not None:
return latents.to(device=device, dtype=dtype)
@@ -568,18 +597,22 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
@property
def guidance_scale(self):
"""Get the current guidance scale value."""
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
"""Check if classifier-free guidance is enabled."""
return self._guidance_scale > 1.0
@property
def num_timesteps(self):
"""Get the number of denoising timesteps."""
return self._num_timesteps
@property
def interrupt(self):
"""Check if generation has been interrupted."""
return self._interrupt
@torch.no_grad()
@@ -590,10 +623,10 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
negative_prompt: Optional[Union[str, List[str]]] = None,
height: int = 512,
width: int = 768,
num_frames: int = 25,
num_frames: int = 121,
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
scheduler_scale: float = 10.0,
scheduler_scale: float = 5.0,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
@@ -715,7 +748,7 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = 16
num_channels_latents = self.transformer.config.in_visual_dim
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_channels_latents,
@@ -728,7 +761,7 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
latents,
)
# 6. Prepare rope positions
# 6. Prepare rope positions for positional encoding
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
visual_rope_pos = [
torch.arange(num_latent_frames, device=device),
@@ -744,7 +777,7 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
else None
)
# 7. Sparse Params
# 7. Sparse Params for efficient attention
sparse_params = self.get_sparse_params(latents, device)
# 8. Denoising loop
@@ -788,9 +821,9 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
pred_velocity - uncond_pred_velocity
)
# Compute previous sample
latents[:, :, :, :, :16] = self.scheduler.step(
pred_velocity, t, latents[:, :, :, :, :16], return_dict=False
# Compute previous sample using the scheduler
latents[:, :, :, :, :num_channels_latents] = self.scheduler.step(
pred_velocity, t, latents[:, :, :, :, :num_channels_latents], return_dict=False
)[0]
if callback_on_step_end is not None:
@@ -809,8 +842,8 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
if XLA_AVAILABLE:
xm.mark_step()
# 8. Post-processing
latents = latents[:, :, :, :, :16]
# 8. Post-processing - extract main latents
latents = latents[:, :, :, :, :num_channels_latents]
# 9. Decode latents to video
if output_type != "latent":
@@ -822,18 +855,18 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
(num_frames - 1) // self.vae_scale_factor_temporal + 1,
height // self.vae_scale_factor_spatial,
width // self.vae_scale_factor_spatial,
16,
num_channels_latents,
)
video = video.permute(0, 1, 5, 2, 3, 4) # [batch, num_videos, channels, frames, height, width]
video = video.reshape(
batch_size * num_videos_per_prompt,
16,
num_channels_latents,
(num_frames - 1) // self.vae_scale_factor_temporal + 1,
height // self.vae_scale_factor_spatial,
width // self.vae_scale_factor_spatial
)
# Normalize and decode
# Normalize and decode through VAE
video = video / self.vae.config.scaling_factor
video = self.vae.decode(video).sample
video = self.video_processor.postprocess_video(video, output_type=output_type)