mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)

Commit: fix prompt type
@@ -33,83 +33,6 @@ from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import KandinskyPipelineOutput
 
 
-if is_torch_xla_available():
-    import torch_xla.core.xla_model as xm
-
-    XLA_AVAILABLE = True
-else:
-    XLA_AVAILABLE = False
-
-logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
-
-if is_ftfy_available():
-    import ftfy
-
-
-logger = logging.get_logger(__name__)
-
-EXAMPLE_DOC_STRING = """
-    Examples:
-        ```python
-        >>> import torch
-        >>> from diffusers import Kandinsky5T2VPipeline, Kandinsky5Transformer3DModel
-        >>> from diffusers.utils import export_to_video
-
-        >>> pipe = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V")
-        >>> pipe = pipe.to("cuda")
-
-        >>> prompt = "A cat and a dog baking a cake together in a kitchen."
-        >>> negative_prompt = "Bright tones, overexposed, static, blurred details"
-
-        >>> output = pipe(
-        ...     prompt=prompt,
-        ...     negative_prompt=negative_prompt,
-        ...     height=512,
-        ...     width=768,
-        ...     num_frames=25,
-        ...     num_inference_steps=50,
-        ...     guidance_scale=5.0,
-        ... ).frames[0]
-        >>> export_to_video(output, "output.mp4", fps=6)
-        ```
-"""
-
-# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import html
-from typing import Any, Callable, Dict, List, Optional, Union
-
-import regex as re
-import torch
-from transformers import Qwen2TokenizerFast, Qwen2VLProcessor, Qwen2_5_VLForConditionalGeneration, AutoProcessor, CLIPTextModel, CLIPTokenizer
-import torchvision
-from torchvision.transforms import ToPILImage
-
-from ...callbacks import MultiPipelineCallbacks, PipelineCallback
-from ...loaders import KandinskyLoraLoaderMixin
-from ...models import AutoencoderKLHunyuanVideo
-from ...models.transformers import Kandinsky5Transformer3DModel
-from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
-from ...utils.torch_utils import randn_tensor
-from ...video_processor import VideoProcessor
-from ..pipeline_utils import DiffusionPipeline
-from .pipeline_output import KandinskyPipelineOutput
-
-
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
@@ -137,23 +60,23 @@ EXAMPLE_DOC_STRING = """
         >>> pipe = pipe.to("cuda")
 
         >>> prompt = "A cat and a dog baking a cake together in a kitchen."
-        >>> negative_prompt = "Bright tones, overexposed, static, blurred details"
+        >>> negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
 
         >>> output = pipe(
         ...     prompt=prompt,
         ...     negative_prompt=negative_prompt,
         ...     height=512,
         ...     width=768,
-        ...     num_frames=25,
+        ...     num_frames=121,
         ...     num_inference_steps=50,
        ...     guidance_scale=5.0,
         ... ).frames[0]
-        >>> export_to_video(output, "output.mp4", fps=6)
+        >>> export_to_video(output, "output.mp4", fps=24)
         ```
 """
 
 
 def basic_clean(text):
+    """Clean text using ftfy if available and unescape HTML entities."""
     if is_ftfy_available():
         text = ftfy.fix_text(text)
     text = html.unescape(html.unescape(text))
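For orientation, the updated example is internally consistent: 121 frames played back at 24 fps gives a roughly five-second clip, and the frame count maps onto whole latent frames. A minimal sketch of the arithmetic, assuming a temporal VAE scale factor of 4 (typical for AutoencoderKLHunyuanVideo; in practice the pipeline derives the real value from the VAE config):

```python
# Frame/fps bookkeeping for the new defaults.
num_frames = 121
fps = 24
vae_scale_factor_temporal = 4  # assumption; read from the pipeline in practice

duration_s = num_frames / fps  # ~5.04 seconds of video
num_latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1  # 31 latent frames
print(duration_s, num_latent_frames)
```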
@@ -161,12 +84,14 @@ def basic_clean(text):
 
 
 def whitespace_clean(text):
+    """Normalize whitespace in text by replacing runs of whitespace with a single space."""
     text = re.sub(r"\s+", " ", text)
     text = text.strip()
     return text
 
 
 def prompt_clean(text):
+    """Apply both basic cleaning and whitespace normalization to prompts."""
     text = whitespace_clean(basic_clean(text))
     return text
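The three helpers compose into prompt_clean, which the encoder path calls on every prompt. A self-contained sketch (with the optional ftfy step omitted) showing why html.unescape runs twice, which handles double-escaped entities:

```python
import html

import regex as re


def basic_clean(text):
    # ftfy.fix_text(text) would run first when ftfy is installed
    return html.unescape(html.unescape(text))


def whitespace_clean(text):
    return re.sub(r"\s+", " ", text).strip()


def prompt_clean(text):
    return whitespace_clean(basic_clean(text))


print(prompt_clean("A   cat &amp;amp; a dog\n\tbaking"))
# -> "A cat & a dog baking"
```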
@@ -228,6 +153,24 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
     def fast_sta_nabla(
         T: int, H: int, W: int, wT: int = 3, wH: int = 3, wW: int = 3, device="cuda"
     ) -> torch.Tensor:
+        """
+        Create a sparse temporal attention (STA) mask for efficient video generation.
+
+        This method generates a mask that limits attention to nearby frames and spatial positions,
+        reducing computational complexity for video generation.
+
+        Args:
+            T (int): Number of temporal frames
+            H (int): Height in latent space
+            W (int): Width in latent space
+            wT (int): Temporal attention window size
+            wH (int): Height attention window size
+            wW (int): Width attention window size
+            device (str): Device to create the tensor on
+
+        Returns:
+            torch.Tensor: Sparse attention mask of shape (T*H*W, T*H*W)
+        """
         l = torch.Tensor([T, H, W]).amax()
         r = torch.arange(0, l, 1, dtype=torch.int16, device=device)
         mat = (r.unsqueeze(1) - r.unsqueeze(0)).abs()
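The visible body builds per-axis |i - j| distance matrices. A hedged sketch of the overall construction (not this file's exact implementation) that produces the documented (T*H*W, T*H*W) mask by AND-ing a banded window along each axis:

```python
import torch


def sta_mask_sketch(T, H, W, wT=3, wH=3, wW=3, device="cpu"):
    # Banded mask per axis: a position attends only within a centered window.
    def axis_band(n, window):
        r = torch.arange(n, device=device)
        return ((r.unsqueeze(1) - r.unsqueeze(0)).abs() <= window // 2).float()

    t, h, w = axis_band(T, wT), axis_band(H, wH), axis_band(W, wW)
    # Outer product over the three axes, laid out as (T, H, W, T, H, W).
    mask = torch.einsum("ab,cd,ef->acebdf", t, h, w)
    return mask.reshape(T * H * W, T * H * W).bool()


m = sta_mask_sketch(4, 3, 3)
print(m.shape, m.float().mean())  # torch.Size([36, 36]) and the kept-attention fraction
```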
@@ -253,6 +196,19 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
         return sta.reshape(T * H * W, T * H * W)
 
     def get_sparse_params(self, sample, device):
+        """
+        Generate sparse attention parameters for the transformer based on sample dimensions.
+
+        This method computes the sparse attention configuration needed for efficient
+        video processing in the transformer model.
+
+        Args:
+            sample (torch.Tensor): Input sample tensor
+            device (torch.device): Device to place tensors on
+
+        Returns:
+            Dict: Dictionary containing sparse attention parameters
+        """
         assert self.transformer.config.patch_size[0] == 1
         B, T, H, W, _ = sample.shape
         T, H, W = (
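A hedged sketch of the dimension bookkeeping the visible lines imply: the latent grid is divided by the transformer patch size to get the token grid that fast_sta_nabla's mask indexes. The patch_size values below are illustrative, not read from the real config:

```python
import torch

sample = torch.randn(1, 31, 64, 96, 16)  # (B, T, H, W, C) latents
patch_size = (1, 2, 2)  # assumed (pT, pH, pW); the assert above requires pT == 1

B, T, H, W, _ = sample.shape
T, H, W = T // patch_size[0], H // patch_size[1], W // patch_size[2]
print(T, H, W, "->", T * H * W, "visual tokens")
```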
@@ -294,12 +250,28 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
         max_sequence_length: int = 256,
         dtype: Optional[torch.dtype] = None,
     ):
+        """
+        Encode prompt using Qwen2.5-VL text encoder.
+
+        This method processes the input prompt through the Qwen2.5-VL model to generate
+        text embeddings suitable for video generation.
+
+        Args:
+            prompt (Union[str, List[str]]): Input prompt or list of prompts
+            device (torch.device): Device to run encoding on
+            num_videos_per_prompt (int): Number of videos to generate per prompt
+            max_sequence_length (int): Maximum sequence length for tokenization
+            dtype (torch.dtype): Data type for embeddings
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: Text embeddings and cumulative sequence lengths
+        """
         device = device or self._execution_device
         dtype = dtype or self.text_encoder.dtype
         prompt = [prompt] if isinstance(prompt, str) else prompt
         prompt = [prompt_clean(p) for p in prompt]
 
-        # Kandinsky specific prompt template
+        # Kandinsky specific prompt template for detailed video description
         prompt_template = "\n".join([
             "<|im_start|>system\nYou are a promt engineer. Describe the video in detail.",
             "Describe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.",
@@ -310,7 +282,7 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
             "Pay attention to the order of key actions shown in the scene.<|im_end|>",
             "<|im_start|>user\n{}<|im_end|>",
         ])
-        crop_start = 129
+        crop_start = 129  # Position to start cropping from (system prompt tokens)
 
         full_texts = [prompt_template.format(p) for p in prompt]
 
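The template contributes a fixed run of system-prompt tokens before the user text, and crop_start = 129 marks where that prefix ends. A minimal sketch of the idea with stand-in tensors (the token counts are illustrative); the cumulative sequence lengths mentioned in the Returns section follow the usual packed varlen-attention convention, which is an assumption here:

```python
import torch

# Slice away the templated system prefix from the encoder output.
seq_len, hidden_dim = 200, 64
crop_start = 129  # tokens contributed by the fixed system template
hidden_states = torch.randn(1, seq_len, hidden_dim)  # stand-in encoder output
prompt_embeds = hidden_states[:, crop_start:]  # (1, 71, 64): user tokens only

# Cumulative sequence lengths for a batch of two prompts after cropping.
lens = torch.tensor([45, 71])
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32), lens.cumsum(0).to(torch.int32)])
print(prompt_embeds.shape, cu_seqlens)  # ... tensor([  0,  45, 116], dtype=torch.int32)
```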
@@ -347,6 +319,21 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
         num_videos_per_prompt: int = 1,
         dtype: Optional[torch.dtype] = None,
     ):
+        """
+        Encode prompt using CLIP text encoder.
+
+        This method processes the input prompt through the CLIP model to generate
+        pooled embeddings that capture semantic information.
+
+        Args:
+            prompt (Union[str, List[str]]): Input prompt or list of prompts
+            device (torch.device): Device to run encoding on
+            num_videos_per_prompt (int): Number of videos to generate per prompt
+            dtype (torch.dtype): Data type for embeddings
+
+        Returns:
+            torch.Tensor: Pooled text embeddings from CLIP
+        """
         device = device or self._execution_device
         dtype = dtype or self.text_encoder_2.dtype
         prompt = [prompt] if isinstance(prompt, str) else prompt
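Unlike the Qwen branch, which keeps per-token states, this branch keeps only CLIP's pooled sentence-level vector. A minimal sketch using the standard transformers CLIP classes (the pipeline holds these as text_encoder_2/tokenizer_2; the checkpoint name below is a generic stand-in, not necessarily the one this pipeline ships):

```python
import torch
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

inputs = tokenizer(
    ["A cat baking a cake"], padding="max_length", max_length=77, truncation=True, return_tensors="pt"
)
with torch.no_grad():
    pooled = text_encoder(input_ids=inputs.input_ids).pooler_output  # (1, 768) pooled embedding
print(pooled.shape)
```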
@@ -386,6 +373,9 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
         r"""
         Encodes the prompt into text encoder hidden states.
 
+        This method combines embeddings from both Qwen2.5-VL and CLIP text encoders
+        to create comprehensive text representations for video generation.
+
         Args:
             prompt (`str` or `List[str]`, *optional*):
                 prompt to be encoded
@@ -410,11 +400,15 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
                 torch device
             dtype: (`torch.dtype`, *optional*):
                 torch dtype
+
+        Returns:
+            Tuple: Contains prompt embeddings, negative prompt embeddings, and sequence length information
+
         """
         device = device or self._execution_device
 
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
             prompt = [prompt]
         elif prompt is not None and isinstance(prompt, list):
             batch_size = len(prompt)
         else:
@@ -438,7 +432,7 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
             prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = prompt_embeds
 
         if do_classifier_free_guidance and negative_prompt_embeds is None:
-            negative_prompt = negative_prompt or ""
+            negative_prompt = negative_prompt or "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
             negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
 
             if prompt is not None and type(prompt) is not type(negative_prompt):
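This new default only matters when classifier-free guidance is active (guidance_scale > 1.0, per the do_classifier_free_guidance property below). A sketch of the standard CFG combination, consistent with the pred_velocity - uncond_pred_velocity term visible in the denoising loop:

```python
import torch

guidance_scale = 5.0
uncond_pred_velocity = torch.randn(1, 16)  # stand-in for the negative-prompt branch
pred_velocity = torch.randn(1, 16)         # stand-in for the positive-prompt branch

# Push the prediction away from the unconditional branch.
guided = uncond_pred_velocity + guidance_scale * (pred_velocity - uncond_pred_velocity)
```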
@@ -492,6 +486,21 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
         negative_prompt_embeds=None,
         callback_on_step_end_tensor_inputs=None,
     ):
+        """
+        Validate input parameters for the pipeline.
+
+        Args:
+            prompt: Input prompt
+            negative_prompt: Negative prompt for guidance
+            height: Video height
+            width: Video width
+            prompt_embeds: Pre-computed prompt embeddings
+            negative_prompt_embeds: Pre-computed negative prompt embeddings
+            callback_on_step_end_tensor_inputs: Callback tensor inputs
+
+        Raises:
+            ValueError: If inputs are invalid
+        """
         if height % 16 != 0 or width % 16 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
 
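The factor of 16 plausibly reflects the VAE's spatial downsampling combined with the transformer's spatial patching (an assumption; the diff does not say). The check itself is a plain divisibility test:

```python
# Mirror of the check_inputs constraint on resolution.
for height, width in [(512, 768), (500, 768)]:
    ok = height % 16 == 0 and width % 16 == 0
    print(height, width, "ok" if ok else "height/width must be divisible by 16")
```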
@@ -535,6 +544,26 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        """
+        Prepare initial latent variables for video generation.
+
+        This method creates random noise latents or uses provided latents as the starting point
+        for the denoising process.
+
+        Args:
+            batch_size (int): Number of videos to generate
+            num_channels_latents (int): Number of channels in latent space
+            height (int): Height of generated video
+            width (int): Width of generated video
+            num_frames (int): Number of frames in video
+            dtype (torch.dtype): Data type for latents
+            device (torch.device): Device to create latents on
+            generator (torch.Generator): Random number generator
+            latents (torch.Tensor): Pre-existing latents to use
+
+        Returns:
+            torch.Tensor: Prepared latent tensor
+        """
         if latents is not None:
             return latents.to(device=device, dtype=dtype)
 
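A sketch of the noise-initialization path using diffusers' randn_tensor (imported at the top of the file), assuming the channels-last latent layout (B, T, H, W, C) that the slicing latents[:, :, :, :, :num_channels_latents] elsewhere in this diff implies; the scale factors are illustrative:

```python
import torch
from diffusers.utils.torch_utils import randn_tensor

batch_size, num_channels_latents = 1, 16
num_frames, height, width = 121, 512, 768
vae_scale_factor_temporal, vae_scale_factor_spatial = 4, 8  # assumptions

shape = (
    batch_size,
    (num_frames - 1) // vae_scale_factor_temporal + 1,  # 31 latent frames
    height // vae_scale_factor_spatial,  # 64
    width // vae_scale_factor_spatial,  # 96
    num_channels_latents,
)
generator = torch.Generator().manual_seed(0)
latents = randn_tensor(shape, generator=generator, dtype=torch.float32)
print(latents.shape)  # torch.Size([1, 31, 64, 96, 16])
```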
@@ -568,18 +597,22 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
 
     @property
     def guidance_scale(self):
+        """Get the current guidance scale value."""
         return self._guidance_scale
 
     @property
     def do_classifier_free_guidance(self):
+        """Check if classifier-free guidance is enabled."""
         return self._guidance_scale > 1.0
 
     @property
     def num_timesteps(self):
+        """Get the number of denoising timesteps."""
         return self._num_timesteps
 
     @property
     def interrupt(self):
+        """Check if generation has been interrupted."""
         return self._interrupt
 
     @torch.no_grad()
@@ -590,10 +623,10 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
         negative_prompt: Optional[Union[str, List[str]]] = None,
         height: int = 512,
         width: int = 768,
-        num_frames: int = 25,
+        num_frames: int = 121,
         num_inference_steps: int = 50,
         guidance_scale: float = 5.0,
-        scheduler_scale: float = 10.0,
+        scheduler_scale: float = 5.0,
         num_videos_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
@@ -715,7 +748,7 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
         timesteps = self.scheduler.timesteps
 
         # 5. Prepare latent variables
-        num_channels_latents = 16
+        num_channels_latents = self.transformer.config.in_visual_dim
         latents = self.prepare_latents(
             batch_size * num_videos_per_prompt,
             num_channels_latents,
@@ -728,7 +761,7 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
             latents,
         )
 
-        # 6. Prepare rope positions
+        # 6. Prepare rope positions for positional encoding
         num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
         visual_rope_pos = [
             torch.arange(num_latent_frames, device=device),
@@ -744,7 +777,7 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
             else None
         )
 
-        # 7. Sparse Params
+        # 7. Sparse Params for efficient attention
        sparse_params = self.get_sparse_params(latents, device)
 
         # 8. Denoising loop
@@ -788,9 +821,9 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
                         pred_velocity - uncond_pred_velocity
                     )
 
-                # Compute previous sample
-                latents[:, :, :, :, :16] = self.scheduler.step(
-                    pred_velocity, t, latents[:, :, :, :, :16], return_dict=False
+                # Compute previous sample using the scheduler
+                latents[:, :, :, :, :num_channels_latents] = self.scheduler.step(
+                    pred_velocity, t, latents[:, :, :, :, :num_channels_latents], return_dict=False
                 )[0]
 
                 if callback_on_step_end is not None:
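Only the first num_channels_latents channels are stepped; any extra channels in the latent tensor are left untouched by the scheduler. A runnable sketch of one flow-matching Euler update on that slice (all shapes illustrative):

```python
import torch
from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler()
scheduler.set_timesteps(50)
t = scheduler.timesteps[0]

num_channels_latents = 16
latents = torch.randn(1, 31, 64, 96, 32)  # 16 stepped channels plus extras (illustrative)
pred_velocity = torch.randn(1, 31, 64, 96, num_channels_latents)

latents[..., :num_channels_latents] = scheduler.step(
    pred_velocity, t, latents[..., :num_channels_latents], return_dict=False
)[0]
```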
@@ -809,8 +842,8 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
                 if XLA_AVAILABLE:
                     xm.mark_step()
 
-        # 8. Post-processing
-        latents = latents[:, :, :, :, :16]
+        # 8. Post-processing - extract main latents
+        latents = latents[:, :, :, :, :num_channels_latents]
 
         # 9. Decode latents to video
         if output_type != "latent":
@@ -822,18 +855,18 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
                 (num_frames - 1) // self.vae_scale_factor_temporal + 1,
                 height // self.vae_scale_factor_spatial,
                 width // self.vae_scale_factor_spatial,
-                16,
+                num_channels_latents,
             )
             video = video.permute(0, 1, 5, 2, 3, 4)  # [batch, num_videos, channels, frames, height, width]
             video = video.reshape(
                 batch_size * num_videos_per_prompt,
-                16,
+                num_channels_latents,
                 (num_frames - 1) // self.vae_scale_factor_temporal + 1,
                 height // self.vae_scale_factor_spatial,
                 width // self.vae_scale_factor_spatial
             )
 
-            # Normalize and decode
+            # Normalize and decode through VAE
             video = video / self.vae.config.scaling_factor
             video = self.vae.decode(video).sample
             video = self.video_processor.postprocess_video(video, output_type=output_type)
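The permute/reshape pair moves the channels-last latents into the channels-first video layout the VAE expects. A shape-only sketch with illustrative sizes:

```python
import torch

batch, num_videos, frames, h, w, c = 1, 1, 31, 64, 96, 16
video = torch.randn(batch, num_videos, frames, h, w, c)  # latents viewed 6-D

video = video.permute(0, 1, 5, 2, 3, 4)  # (batch, num_videos, channels, frames, height, width)
video = video.reshape(batch * num_videos, c, frames, h, w)
print(video.shape)  # torch.Size([1, 16, 31, 64, 96]); ready for vae.decode
```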