1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Merge branch 'main' into bria-fibo

This commit is contained in:
galbria
2025-10-28 09:47:41 +02:00
committed by GitHub
19 changed files with 664 additions and 47 deletions

View File

@@ -529,6 +529,8 @@
title: Kandinsky 2.2
- local: api/pipelines/kandinsky3
title: Kandinsky 3
- local: api/pipelines/kandinsky5
title: Kandinsky 5
- local: api/pipelines/kolors
title: Kolors
- local: api/pipelines/latent_consistency_models

View File

@@ -0,0 +1,149 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Kandinsky 5.0
Kandinsky 5.0 is created by the Kandinsky team: Alexey Letunovskiy, Maria Kovaleva, Ivan Kirillov, Lev Novitskiy, Denis Koposov, Dmitrii Mikhailov, Anna Averchenkova, Andrey Shutkin, Julia Agafonova, Olga Kim, Anastasiia Kargapoltseva, Nikita Kiselev, Anna Dmitrienko, Anastasia Maltseva, Kirill Chernyshev, Ilia Vasiliev, Viacheslav Vasilev, Vladimir Polovnikov, Yury Kolabushin, Alexander Belykh, Mikhail Mamaev, Anastasia Aliaskina, Tatiana Nikulina, Polina Gavrilova, Vladimir Arkhipkin, Vladimir Korviakov, Nikolai Gerasimenko, Denis Parkhomenko, Denis Dimitrov
Kandinsky 5.0 is a family of diffusion models for Video & Image generation. Kandinsky 5.0 T2V Lite is a lightweight video generation model (2B parameters) that ranks #1 among open-source models in its class. It outperforms larger models and offers the best understanding of Russian concepts in the open-source ecosystem.
The model introduces several key innovations:
- **Latent diffusion pipeline** with **Flow Matching** for improved training stability
- **Diffusion Transformer (DiT)** as the main generative backbone with cross-attention to text embeddings
- Dual text encoding using **Qwen2.5-VL** and **CLIP** for comprehensive text understanding
- **HunyuanVideo 3D VAE** for efficient video encoding and decoding
- **Sparse attention mechanisms** (NABLA) for efficient long-sequence processing
The original codebase can be found at [ai-forever/Kandinsky-5](https://github.com/ai-forever/Kandinsky-5).
> [!TIP]
> Check out the [AI Forever](https://huggingface.co/ai-forever) organization on the Hub for the official model checkpoints for text-to-video generation, including pretrained, SFT, no-CFG, and distilled variants.
## Available Models
Kandinsky 5.0 T2V Lite comes in several variants optimized for different use cases:
| model_id | Description | Use Cases |
|------------|-------------|-----------|
| **ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers** | 5 second Supervised Fine-Tuned model | Highest generation quality |
| **ai-forever/Kandinsky-5.0-T2V-Lite-sft-10s-Diffusers** | 10 second Supervised Fine-Tuned model | Highest generation quality |
| **ai-forever/Kandinsky-5.0-T2V-Lite-nocfg-5s-Diffusers** | 5 second Classifier-Free Guidance distilled | 2× faster inference |
| **ai-forever/Kandinsky-5.0-T2V-Lite-nocfg-10s-Diffusers** | 10 second Classifier-Free Guidance distilled | 2× faster inference |
| **ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers** | 5 second Diffusion distilled to 16 steps | 6× faster inference, minimal quality loss |
| **ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-10s-Diffusers** | 10 second Diffusion distilled to 16 steps | 6× faster inference, minimal quality loss |
| **ai-forever/Kandinsky-5.0-T2V-Lite-pretrain-5s-Diffusers** | 5 second Base pretrained model | Research and fine-tuning |
| **ai-forever/Kandinsky-5.0-T2V-Lite-pretrain-10s-Diffusers** | 10 second Base pretrained model | Research and fine-tuning |
All models are available in 5-second and 10-second video generation versions.
## Kandinsky5T2VPipeline
[[autodoc]] Kandinsky5T2VPipeline
- all
- __call__
## Usage Examples
### Basic Text-to-Video Generation
```python
import torch
from diffusers import Kandinsky5T2VPipeline
from diffusers.utils import export_to_video
# Load the pipeline
model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers"
pipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")
# Generate video
prompt = "A cat and a dog baking a cake together in a kitchen."
negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
output = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
height=512,
width=768,
num_frames=121, # ~5 seconds at 24fps
num_inference_steps=50,
guidance_scale=5.0,
).frames[0]
export_to_video(output, "output.mp4", fps=24, quality=9)
```
### 10 second Models
**⚠️ Warning!** all 10 second models should be used with Flex attention and max-autotune-no-cudagraphs compilation:
```python
pipe = Kandinsky5T2VPipeline.from_pretrained(
"ai-forever/Kandinsky-5.0-T2V-Lite-sft-10s-Diffusers",
torch_dtype=torch.bfloat16
)
pipe = pipe.to("cuda")
pipe.transformer.set_attention_backend(
"flex"
) # <--- Set attention backend to Flex
pipe.transformer.compile(
mode="max-autotune-no-cudagraphs",
dynamic=True
) # <--- Compile with max-autotune-no-cudagraphs
prompt = "A cat and a dog baking a cake together in a kitchen."
negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
output = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
height=512,
width=768,
num_frames=241,
num_inference_steps=50,
guidance_scale=5.0,
).frames[0]
export_to_video(output, "output.mp4", fps=24, quality=9)
```
### Diffusion Distilled model
**⚠️ Warning!** all nocfg and diffusion distilled models should be inferred without CFG (```guidance_scale=1.0```):
```python
model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers"
pipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")
output = pipe(
prompt="A beautiful sunset over mountains",
num_inference_steps=16, # <--- Model is distilled in 16 steps
guidance_scale=1.0, # <--- no CFG
).frames[0]
export_to_video(output, "output.mp4", fps=24, quality=9)
```
## Citation
```bibtex
@misc{kandinsky2025,
author = {Alexey Letunovskiy and Maria Kovaleva and Ivan Kirillov and Lev Novitskiy and Denis Koposov and
Dmitrii Mikhailov and Anna Averchenkova and Andrey Shutkin and Julia Agafonova and Olga Kim and
Anastasiia Kargapoltseva and Nikita Kiselev and Vladimir Arkhipkin and Vladimir Korviakov and
Nikolai Gerasimenko and Denis Parkhomenko and Anna Dmitrienko and Anastasia Maltseva and
Kirill Chernyshev and Ilia Vasiliev and Viacheslav Vasilev and Vladimir Polovnikov and
Yury Kolabushin and Alexander Belykh and Mikhail Mamaev and Anastasia Aliaskina and
Tatiana Nikulina and Polina Gavrilova and Denis Dimitrov},
title = {Kandinsky 5.0: A family of diffusion models for Video & Image generation},
howpublished = {\url{https://github.com/ai-forever/Kandinsky-5}},
year = 2025
}
```

View File

@@ -1977,14 +1977,34 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
"time_projection.1.diff_b"
)
if any("head.head" in k for k in state_dict):
converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop(
f"head.head.{lora_down_key}.weight"
)
converted_state_dict["proj_out.lora_B.weight"] = original_state_dict.pop(f"head.head.{lora_up_key}.weight")
if any("head.head" in k for k in original_state_dict):
if any(f"head.head.{lora_down_key}.weight" in k for k in state_dict):
converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop(
f"head.head.{lora_down_key}.weight"
)
if any(f"head.head.{lora_up_key}.weight" in k for k in state_dict):
converted_state_dict["proj_out.lora_B.weight"] = original_state_dict.pop(
f"head.head.{lora_up_key}.weight"
)
if "head.head.diff_b" in original_state_dict:
converted_state_dict["proj_out.lora_B.bias"] = original_state_dict.pop("head.head.diff_b")
# Notes: https://huggingface.co/lightx2v/Wan2.2-Distill-Loras
# This is my (sayakpaul) assumption that this particular key belongs to the down matrix.
# Since for this particular LoRA, we don't have the corresponding up matrix, I will use
# an identity.
if any("head.head" in k and k.endswith(".diff") for k in state_dict):
if f"head.head.{lora_down_key}.weight" in state_dict:
logger.info(
f"The state dict seems to be have both `head.head.diff` and `head.head.{lora_down_key}.weight` keys, which is unexpected."
)
converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop("head.head.diff")
down_matrix_head = converted_state_dict["proj_out.lora_A.weight"]
up_matrix_shape = (down_matrix_head.shape[0], converted_state_dict["proj_out.lora_B.bias"].shape[0])
converted_state_dict["proj_out.lora_B.weight"] = torch.eye(
*up_matrix_shape, dtype=down_matrix_head.dtype, device=down_matrix_head.device
).T
for text_time in ["text_embedding", "time_embedding"]:
if any(text_time in k for k in original_state_dict):
for b_n in [0, 2]:

View File

@@ -1337,9 +1337,18 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
tile_sample_stride_height = self.tile_sample_stride_height
tile_sample_stride_width = self.tile_sample_stride_width
if self.config.patch_size is not None:
sample_height = sample_height // self.config.patch_size
sample_width = sample_width // self.config.patch_size
tile_sample_stride_height = tile_sample_stride_height // self.config.patch_size
tile_sample_stride_width = tile_sample_stride_width // self.config.patch_size
blend_height = self.tile_sample_min_height // self.config.patch_size - tile_sample_stride_height
blend_width = self.tile_sample_min_width // self.config.patch_size - tile_sample_stride_width
else:
blend_height = self.tile_sample_min_height - tile_sample_stride_height
blend_width = self.tile_sample_min_width - tile_sample_stride_width
# Split z into overlapping tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
@@ -1353,7 +1362,9 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
self._conv_idx = [0]
tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
tile = self.post_quant_conv(tile)
decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
decoded = self.decoder(
tile, feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=(k == 0)
)
time.append(decoded)
row.append(torch.cat(time, dim=2))
rows.append(row)
@@ -1369,11 +1380,15 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
result_row.append(tile[:, :, :, :tile_sample_stride_height, :tile_sample_stride_width])
result_rows.append(torch.cat(result_row, dim=-1))
dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
if self.config.patch_size is not None:
dec = unpatchify(dec, patch_size=self.config.patch_size)
dec = torch.clamp(dec, min=-1.0, max=1.0)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)

View File

@@ -22,7 +22,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
@@ -717,7 +717,11 @@ class FluxTransformer2DModel(
img_ids = img_ids[0]
ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids)
if is_torch_npu_available():
freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
else:
image_rotary_emb = self.pos_embed(ids)
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")

View File

@@ -324,6 +324,7 @@ class Kandinsky5AttnProcessor:
sparse_params["sta_mask"],
thr=sparse_params["P"],
)
else:
attn_mask = None
@@ -335,6 +336,7 @@ class Kandinsky5AttnProcessor:
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(-2, -1)
attn_out = attn.out_layer(hidden_states)

View File

@@ -266,7 +266,7 @@ class StableDiffusion3ControlNetPipeline(
return torch.zeros(
(
batch_size * num_images_per_prompt,
self.tokenizer_max_length,
max_sequence_length,
self.transformer.config.joint_attention_dim,
),
device=device,

View File

@@ -284,7 +284,7 @@ class StableDiffusion3ControlNetInpaintingPipeline(
return torch.zeros(
(
batch_size * num_images_per_prompt,
self.tokenizer_max_length,
max_sequence_length,
self.transformer.config.joint_attention_dim,
),
device=device,

View File

@@ -173,8 +173,10 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
)
self.prompt_template_encode_start_idx = 129
self.vae_scale_factor_temporal = vae.config.temporal_compression_ratio
self.vae_scale_factor_spatial = vae.config.spatial_compression_ratio
self.vae_scale_factor_temporal = (
self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
)
self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
@staticmethod
@@ -384,6 +386,9 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
if not isinstance(prompt, list):
prompt = [prompt]
batch_size = len(prompt)
prompt = [prompt_clean(p) for p in prompt]

View File

@@ -237,7 +237,7 @@ class StableDiffusion3PAGPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSin
return torch.zeros(
(
batch_size * num_images_per_prompt,
self.tokenizer_max_length,
max_sequence_length,
self.transformer.config.joint_attention_dim,
),
device=device,

View File

@@ -253,7 +253,7 @@ class StableDiffusion3PAGImg2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
return torch.zeros(
(
batch_size * num_images_per_prompt,
self.tokenizer_max_length,
max_sequence_length,
self.transformer.config.joint_attention_dim,
),
device=device,

View File

@@ -248,7 +248,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
return torch.zeros(
(
batch_size * num_images_per_prompt,
self.tokenizer_max_length,
max_sequence_length,
self.transformer.config.joint_attention_dim,
),
device=device,

View File

@@ -272,7 +272,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
return torch.zeros(
(
batch_size * num_images_per_prompt,
self.tokenizer_max_length,
max_sequence_length,
self.transformer.config.joint_attention_dim,
),
device=device,

View File

@@ -278,7 +278,7 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
return torch.zeros(
(
batch_size * num_images_per_prompt,
self.tokenizer_max_length,
max_sequence_length,
self.transformer.config.joint_attention_dim,
),
device=device,

View File

@@ -152,34 +152,36 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
text_encoder ([`T5EncoderModel`]):
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
transformer ([`WanVACETransformer3DModel`]):
Conditional Transformer to denoise the input latents.
transformer_2 ([`WanVACETransformer3DModel`], *optional*):
Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
`transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided, only
`transformer` is used.
scheduler ([`UniPCMultistepScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLWan`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
scheduler ([`UniPCMultistepScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
transformer ([`WanVACETransformer3DModel`], *optional*):
Conditional Transformer to denoise the input latents during the high-noise stage. In two-stage denoising,
`transformer` handles high-noise stages and `transformer_2` handles low-noise stages. At least one of
`transformer` or `transformer_2` must be provided.
transformer_2 ([`WanVACETransformer3DModel`], *optional*):
Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
`transformer` handles high-noise stages and `transformer_2` handles low-noise stages. At least one of
`transformer` or `transformer_2` must be provided.
boundary_ratio (`float`, *optional*, defaults to `None`):
Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
`transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
boundary_timestep. If `None`, only the available transformer is used for the entire denoising process.
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
_optional_components = ["transformer_2"]
_optional_components = ["transformer", "transformer_2"]
def __init__(
self,
tokenizer: AutoTokenizer,
text_encoder: UMT5EncoderModel,
transformer: WanVACETransformer3DModel,
vae: AutoencoderKLWan,
scheduler: FlowMatchEulerDiscreteScheduler,
transformer: WanVACETransformer3DModel = None,
transformer_2: WanVACETransformer3DModel = None,
boundary_ratio: Optional[float] = None,
):
@@ -336,7 +338,15 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
reference_images=None,
guidance_scale_2=None,
):
base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
if self.transformer is not None:
base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
elif self.transformer_2 is not None:
base = self.vae_scale_factor_spatial * self.transformer_2.config.patch_size[1]
else:
raise ValueError(
"`transformer` or `transformer_2` component must be set in order to run inference with this pipeline"
)
if height % base != 0 or width % base != 0:
raise ValueError(f"`height` and `width` have to be divisible by {base} but are {height} and {width}.")
@@ -414,7 +424,11 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
device: Optional[torch.device] = None,
):
if video is not None:
base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
base = self.vae_scale_factor_spatial * (
self.transformer.config.patch_size[1]
if self.transformer is not None
else self.transformer_2.config.patch_size[1]
)
video_height, video_width = self.video_processor.get_default_height_width(video[0])
if video_height * video_width > height * width:
@@ -589,7 +603,11 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
"Generating with more than one video is not yet supported. This may be supported in the future."
)
transformer_patch_size = self.transformer.config.patch_size[1]
transformer_patch_size = (
self.transformer.config.patch_size[1]
if self.transformer is not None
else self.transformer_2.config.patch_size[1]
)
mask_list = []
for mask_, reference_images_batch in zip(mask, reference_images):
@@ -844,20 +862,25 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
batch_size = prompt_embeds.shape[0]
vae_dtype = self.vae.dtype
transformer_dtype = self.transformer.dtype
transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
vace_layers = (
self.transformer.config.vace_layers
if self.transformer is not None
else self.transformer_2.config.vace_layers
)
if isinstance(conditioning_scale, (int, float)):
conditioning_scale = [conditioning_scale] * len(self.transformer.config.vace_layers)
conditioning_scale = [conditioning_scale] * len(vace_layers)
if isinstance(conditioning_scale, list):
if len(conditioning_scale) != len(self.transformer.config.vace_layers):
if len(conditioning_scale) != len(vace_layers):
raise ValueError(
f"Length of `conditioning_scale` {len(conditioning_scale)} does not match number of layers {len(self.transformer.config.vace_layers)}."
f"Length of `conditioning_scale` {len(conditioning_scale)} does not match number of layers {len(vace_layers)}."
)
conditioning_scale = torch.tensor(conditioning_scale)
if isinstance(conditioning_scale, torch.Tensor):
if conditioning_scale.size(0) != len(self.transformer.config.vace_layers):
if conditioning_scale.size(0) != len(vace_layers):
raise ValueError(
f"Length of `conditioning_scale` {conditioning_scale.size(0)} does not match number of layers {len(self.transformer.config.vace_layers)}."
f"Length of `conditioning_scale` {conditioning_scale.size(0)} does not match number of layers {len(vace_layers)}."
)
conditioning_scale = conditioning_scale.to(device=device, dtype=transformer_dtype)
@@ -900,7 +923,11 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
conditioning_latents = torch.cat([conditioning_latents, mask], dim=1)
conditioning_latents = conditioning_latents.to(transformer_dtype)
num_channels_latents = self.transformer.config.in_channels
num_channels_latents = (
self.transformer.config.in_channels
if self.transformer is not None
else self.transformer_2.config.in_channels
)
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_channels_latents,
@@ -968,7 +995,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

View File

View File

@@ -0,0 +1,306 @@
# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from transformers import (
CLIPTextConfig,
CLIPTextModel,
CLIPTokenizer,
Qwen2_5_VLConfig,
Qwen2_5_VLForConditionalGeneration,
Qwen2VLProcessor,
)
from diffusers import (
AutoencoderKLHunyuanVideo,
FlowMatchEulerDiscreteScheduler,
Kandinsky5T2VPipeline,
Kandinsky5Transformer3DModel,
)
from ...testing_utils import (
enable_full_determinism,
torch_device,
)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
class Kandinsky5T2VPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = Kandinsky5T2VPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "prompt_embeds", "negative_prompt_embeds"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
# Define required optional parameters for your pipeline
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"latents",
"return_dict",
"callback_on_step_end",
"callback_on_step_end_tensor_inputs",
"max_sequence_length",
]
)
test_xformers_attention = False
supports_dduf = False
def get_dummy_components(self):
torch.manual_seed(0)
vae = AutoencoderKLHunyuanVideo(
in_channels=3,
out_channels=3,
spatial_compression_ratio=8,
temporal_compression_ratio=4,
latent_channels=4,
block_out_channels=(8, 8, 8, 8),
layers_per_block=1,
norm_num_groups=4,
)
torch.manual_seed(0)
scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
# Dummy Qwen2.5-VL model
config = Qwen2_5_VLConfig(
text_config={
"hidden_size": 16,
"intermediate_size": 16,
"num_hidden_layers": 2,
"num_attention_heads": 2,
"num_key_value_heads": 2,
"rope_scaling": {
"mrope_section": [1, 1, 2],
"rope_type": "default",
"type": "default",
},
"rope_theta": 1000000.0,
},
vision_config={
"depth": 2,
"hidden_size": 16,
"intermediate_size": 16,
"num_heads": 2,
"out_hidden_size": 16,
},
hidden_size=16,
vocab_size=152064,
vision_end_token_id=151653,
vision_start_token_id=151652,
vision_token_id=151654,
)
text_encoder = Qwen2_5_VLForConditionalGeneration(config)
tokenizer = Qwen2VLProcessor.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
# Dummy CLIP model
clip_text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
hidden_act="gelu",
projection_dim=32,
)
torch.manual_seed(0)
text_encoder_2 = CLIPTextModel(clip_text_encoder_config)
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
torch.manual_seed(0)
transformer = Kandinsky5Transformer3DModel(
in_visual_dim=4,
in_text_dim=16, # Match tiny Qwen2.5-VL hidden size
in_text_dim2=32, # Match tiny CLIP hidden size
time_dim=32,
out_visual_dim=4,
patch_size=(1, 2, 2),
model_dim=48,
ff_dim=128,
num_text_blocks=1,
num_visual_blocks=1,
axes_dims=(8, 8, 8),
visual_cond=False,
)
components = {
"transformer": transformer.eval(),
"vae": vae.eval(),
"scheduler": scheduler,
"text_encoder": text_encoder.eval(),
"tokenizer": tokenizer,
"text_encoder_2": text_encoder_2.eval(),
"tokenizer_2": tokenizer_2,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "A cat dancing",
"negative_prompt": "blurry, low quality",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"height": 32,
"width": 32,
"num_frames": 5,
"max_sequence_length": 16,
"output_type": "pt",
}
return inputs
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames
# Check video shape: (batch, frames, channel, height, width)
expected_shape = (1, 5, 3, 32, 32)
self.assertEqual(video.shape, expected_shape)
# Check specific values
expected_slice = torch.tensor(
[
0.4330,
0.4254,
0.4285,
0.3835,
0.4253,
0.4196,
0.3704,
0.3714,
0.4999,
0.5346,
0.4795,
0.4637,
0.4930,
0.5124,
0.4902,
0.4570,
]
)
generated_slice = video.flatten()
# Take first 8 and last 8 values for comparison
video_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
self.assertTrue(
torch.allclose(video_slice, expected_slice, atol=1e-3),
f"video_slice: {video_slice}, expected_slice: {expected_slice}",
)
def test_inference_batch_single_identical(self):
# Override to test batch single identical with video
super().test_inference_batch_single_identical(batch_size=2, expected_max_diff=1e-2)
def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=1e-3, rtol=1e-3):
components = self.get_dummy_components()
text_component_names = ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"]
text_components = {k: (v if k in text_component_names else None) for k, v in components.items()}
non_text_components = {k: (v if k not in text_component_names else None) for k, v in components.items()}
pipe_with_just_text_encoder = self.pipeline_class(**text_components)
pipe_with_just_text_encoder = pipe_with_just_text_encoder.to(torch_device)
pipe_without_text_encoders = self.pipeline_class(**non_text_components)
pipe_without_text_encoders = pipe_without_text_encoders.to(torch_device)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
# Compute `encode_prompt()`.
# Test single prompt
prompt = "A cat dancing"
with torch.no_grad():
prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = pipe_with_just_text_encoder.encode_prompt(
prompt, device=torch_device, max_sequence_length=16
)
# Check shapes
self.assertEqual(prompt_embeds_qwen.shape, (1, 4, 16)) # [batch, seq_len, embed_dim]
self.assertEqual(prompt_embeds_clip.shape, (1, 32)) # [batch, embed_dim]
self.assertEqual(prompt_cu_seqlens.shape, (2,)) # [batch + 1]
# Test batch of prompts
prompts = ["A cat dancing", "A dog running"]
with torch.no_grad():
batch_embeds_qwen, batch_embeds_clip, batch_cu_seqlens = pipe_with_just_text_encoder.encode_prompt(
prompts, device=torch_device, max_sequence_length=16
)
# Check batch size
self.assertEqual(batch_embeds_qwen.shape, (len(prompts), 4, 16))
self.assertEqual(batch_embeds_clip.shape, (len(prompts), 32))
self.assertEqual(len(batch_cu_seqlens), len(prompts) + 1) # [0, len1, len1+len2]
inputs = self.get_dummy_inputs(torch_device)
inputs["guidance_scale"] = 1.0
# baseline output: full pipeline
pipe_out = pipe(**inputs).frames
# test against pipeline call with pre-computed prompt embeds
inputs = self.get_dummy_inputs(torch_device)
inputs["guidance_scale"] = 1.0
with torch.no_grad():
prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = pipe_with_just_text_encoder.encode_prompt(
inputs["prompt"], device=torch_device, max_sequence_length=inputs["max_sequence_length"]
)
inputs["prompt"] = None
inputs["prompt_embeds_qwen"] = prompt_embeds_qwen
inputs["prompt_embeds_clip"] = prompt_embeds_clip
inputs["prompt_cu_seqlens"] = prompt_cu_seqlens
pipe_out_2 = pipe_without_text_encoders(**inputs)[0]
self.assertTrue(
torch.allclose(pipe_out, pipe_out_2, atol=atol, rtol=rtol),
f"max diff: {torch.max(torch.abs(pipe_out - pipe_out_2))}",
)
@unittest.skip("Kandinsky5T2VPipeline does not support attention slicing")
def test_attention_slicing_forward_pass(self):
pass
@unittest.skip("Kandinsky5T2VPipeline does not support xformers")
def test_xformers_attention_forwardGenerator_pass(self):
pass
@unittest.skip("Kandinsky5T2VPipeline does not support VAE slicing")
def test_vae_slicing(self):
pass

View File

@@ -1461,6 +1461,8 @@ class PipelineTesterMixin:
def test_save_load_optional_components(self, expected_max_difference=1e-4):
if not hasattr(self.pipeline_class, "_optional_components"):
return
if not self.pipeline_class._optional_components:
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
import numpy as np
@@ -19,9 +20,15 @@ import torch
from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanVACEPipeline, WanVACETransformer3DModel
from diffusers import (
AutoencoderKLWan,
FlowMatchEulerDiscreteScheduler,
UniPCMultistepScheduler,
WanVACEPipeline,
WanVACETransformer3DModel,
)
from ...testing_utils import enable_full_determinism
from ...testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
@@ -212,3 +219,81 @@ class WanVACEPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
)
def test_save_load_float16(self):
pass
def test_inference_with_only_transformer(self):
components = self.get_dummy_components()
components["transformer_2"] = None
components["boundary_ratio"] = 0.0
pipe = self.pipeline_class(**components)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
video = pipe(**inputs).frames[0]
assert video.shape == (17, 3, 16, 16)
def test_inference_with_only_transformer_2(self):
components = self.get_dummy_components()
components["transformer_2"] = components["transformer"]
components["transformer"] = None
# FlowMatchEulerDiscreteScheduler doesn't support running low noise only scheduler
# because starting timestep t == 1000 == boundary_timestep
components["scheduler"] = UniPCMultistepScheduler(
prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0
)
components["boundary_ratio"] = 1.0
pipe = self.pipeline_class(**components)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
video = pipe(**inputs).frames[0]
assert video.shape == (17, 3, 16, 16)
def test_save_load_optional_components(self, expected_max_difference=1e-4):
optional_component = ["transformer"]
components = self.get_dummy_components()
components["transformer_2"] = components["transformer"]
# FlowMatchEulerDiscreteScheduler doesn't support running low noise only scheduler
# because starting timestep t == 1000 == boundary_timestep
components["scheduler"] = UniPCMultistepScheduler(
prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0
)
for component in optional_component:
components[component] = None
components["boundary_ratio"] = 1.0
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
generator_device = "cpu"
inputs = self.get_dummy_inputs(generator_device)
torch.manual_seed(0)
output = pipe(**inputs)[0]
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir, safe_serialization=False)
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
for component in pipe_loaded.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe_loaded.to(torch_device)
pipe_loaded.set_progress_bar_config(disable=None)
for component in optional_component:
assert getattr(pipe_loaded, component) is None, f"`{component}` did not stay set to None after loading."
inputs = self.get_dummy_inputs(generator_device)
torch.manual_seed(0)
output_loaded = pipe_loaded(**inputs)[0]
max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
assert max_diff < expected_max_difference, "Outputs exceed expecpted maximum difference"