Merge branch 'bria-fibo' of github.com:galbria/diffusers into bria-fibo
@@ -529,6 +529,8 @@
       title: Kandinsky 2.2
     - local: api/pipelines/kandinsky3
       title: Kandinsky 3
+    - local: api/pipelines/kandinsky5
+      title: Kandinsky 5
     - local: api/pipelines/kolors
       title: Kolors
     - local: api/pipelines/latent_consistency_models
docs/source/en/api/pipelines/kandinsky5.md (new file, 149 lines)
@@ -0,0 +1,149 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Kandinsky 5.0

Kandinsky 5.0 was created by the Kandinsky team: Alexey Letunovskiy, Maria Kovaleva, Ivan Kirillov, Lev Novitskiy, Denis Koposov, Dmitrii Mikhailov, Anna Averchenkova, Andrey Shutkin, Julia Agafonova, Olga Kim, Anastasiia Kargapoltseva, Nikita Kiselev, Anna Dmitrienko, Anastasia Maltseva, Kirill Chernyshev, Ilia Vasiliev, Viacheslav Vasilev, Vladimir Polovnikov, Yury Kolabushin, Alexander Belykh, Mikhail Mamaev, Anastasia Aliaskina, Tatiana Nikulina, Polina Gavrilova, Vladimir Arkhipkin, Vladimir Korviakov, Nikolai Gerasimenko, Denis Parkhomenko, Denis Dimitrov

Kandinsky 5.0 is a family of diffusion models for video and image generation. Kandinsky 5.0 T2V Lite is a lightweight (2B parameter) text-to-video model that ranks #1 among open-source models in its class: it outperforms larger models and offers the best understanding of Russian concepts in the open-source ecosystem.

The model introduces several key innovations:
- **Latent diffusion pipeline** with **Flow Matching** for improved training stability
- **Diffusion Transformer (DiT)** as the main generative backbone with cross-attention to text embeddings
- Dual text encoding using **Qwen2.5-VL** and **CLIP** for comprehensive text understanding
- **HunyuanVideo 3D VAE** for efficient video encoding and decoding
- **Sparse attention mechanisms** (NABLA) for efficient long-sequence processing

The original codebase can be found at [ai-forever/Kandinsky-5](https://github.com/ai-forever/Kandinsky-5).

> [!TIP]
> Check out the [AI Forever](https://huggingface.co/ai-forever) organization on the Hub for the official model checkpoints for text-to-video generation, including pretrained, SFT, no-CFG, and distilled variants.

## Available Models

Kandinsky 5.0 T2V Lite comes in several variants optimized for different use cases:

| model_id | Description | Use Cases |
|----------|-------------|-----------|
| **ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers** | 5-second supervised fine-tuned (SFT) model | Highest generation quality |
| **ai-forever/Kandinsky-5.0-T2V-Lite-sft-10s-Diffusers** | 10-second supervised fine-tuned (SFT) model | Highest generation quality |
| **ai-forever/Kandinsky-5.0-T2V-Lite-nocfg-5s-Diffusers** | 5-second classifier-free-guidance-distilled model | 2× faster inference |
| **ai-forever/Kandinsky-5.0-T2V-Lite-nocfg-10s-Diffusers** | 10-second classifier-free-guidance-distilled model | 2× faster inference |
| **ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers** | 5-second model distilled to 16 diffusion steps | 6× faster inference, minimal quality loss |
| **ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-10s-Diffusers** | 10-second model distilled to 16 diffusion steps | 6× faster inference, minimal quality loss |
| **ai-forever/Kandinsky-5.0-T2V-Lite-pretrain-5s-Diffusers** | 5-second base pretrained model | Research and fine-tuning |
| **ai-forever/Kandinsky-5.0-T2V-Lite-pretrain-10s-Diffusers** | 10-second base pretrained model | Research and fine-tuning |

All models are available in 5-second and 10-second video generation versions.

## Kandinsky5T2VPipeline

[[autodoc]] Kandinsky5T2VPipeline
  - all
  - __call__

## Usage Examples

### Basic Text-to-Video Generation

```python
import torch
from diffusers import Kandinsky5T2VPipeline
from diffusers.utils import export_to_video

# Load the pipeline
model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers"
pipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")

# Generate video
prompt = "A cat and a dog baking a cake together in a kitchen."
negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"

output = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=512,
    width=768,
    num_frames=121,  # ~5 seconds at 24fps
    num_inference_steps=50,
    guidance_scale=5.0,
).frames[0]

export_to_video(output, "output.mp4", fps=24, quality=9)
```
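
If the checkpoint does not fit into GPU memory, the standard `DiffusionPipeline` offloading helpers are usually a drop-in replacement for `pipe.to("cuda")`. A minimal sketch, assuming the generic offloading hooks apply to this pipeline as they do to other Diffusers video pipelines:

```python
import torch
from diffusers import Kandinsky5T2VPipeline

pipe = Kandinsky5T2VPipeline.from_pretrained(
    "ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers", torch_dtype=torch.bfloat16
)
# Keep submodules on CPU and move each one to the GPU only while it runs.
pipe.enable_model_cpu_offload()
```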

### 10-Second Models

**⚠️ Warning!** All 10-second models should be used with the Flex attention backend and `max-autotune-no-cudagraphs` compilation:

```python
pipe = Kandinsky5T2VPipeline.from_pretrained(
    "ai-forever/Kandinsky-5.0-T2V-Lite-sft-10s-Diffusers",
    torch_dtype=torch.bfloat16,
)
pipe = pipe.to("cuda")

pipe.transformer.set_attention_backend("flex")  # <--- set the attention backend to Flex
pipe.transformer.compile(
    mode="max-autotune-no-cudagraphs",
    dynamic=True,
)  # <--- compile with max-autotune-no-cudagraphs

prompt = "A cat and a dog baking a cake together in a kitchen."
negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"

output = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=512,
    width=768,
    num_frames=241,  # ~10 seconds at 24fps
    num_inference_steps=50,
    guidance_scale=5.0,
).frames[0]

export_to_video(output, "output.mp4", fps=24, quality=9)
```

### Diffusion-Distilled Models

**⚠️ Warning!** All no-CFG and diffusion-distilled models should be run without classifier-free guidance (`guidance_scale=1.0`):

```python
model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers"
pipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")

output = pipe(
    prompt="A beautiful sunset over mountains",
    num_inference_steps=16,  # <--- the model is distilled to 16 steps
    guidance_scale=1.0,  # <--- no CFG
).frames[0]

export_to_video(output, "output.mp4", fps=24, quality=9)
```

## Citation

```bibtex
@misc{kandinsky2025,
  author       = {Alexey Letunovskiy and Maria Kovaleva and Ivan Kirillov and Lev Novitskiy and Denis Koposov and
                  Dmitrii Mikhailov and Anna Averchenkova and Andrey Shutkin and Julia Agafonova and Olga Kim and
                  Anastasiia Kargapoltseva and Nikita Kiselev and Vladimir Arkhipkin and Vladimir Korviakov and
                  Nikolai Gerasimenko and Denis Parkhomenko and Anna Dmitrienko and Anastasia Maltseva and
                  Kirill Chernyshev and Ilia Vasiliev and Viacheslav Vasilev and Vladimir Polovnikov and
                  Yury Kolabushin and Alexander Belykh and Mikhail Mamaev and Anastasia Aliaskina and
                  Tatiana Nikulina and Polina Gavrilova and Denis Dimitrov},
  title        = {Kandinsky 5.0: A family of diffusion models for Video & Image generation},
  howpublished = {\url{https://github.com/ai-forever/Kandinsky-5}},
  year         = 2025
}
```
@@ -1977,14 +1977,34 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
         "time_projection.1.diff_b"
     )
 
-    if any("head.head" in k for k in state_dict):
-        converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop(
-            f"head.head.{lora_down_key}.weight"
-        )
-        converted_state_dict["proj_out.lora_B.weight"] = original_state_dict.pop(f"head.head.{lora_up_key}.weight")
+    if any("head.head" in k for k in original_state_dict):
+        if any(f"head.head.{lora_down_key}.weight" in k for k in state_dict):
+            converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop(
+                f"head.head.{lora_down_key}.weight"
+            )
+        if any(f"head.head.{lora_up_key}.weight" in k for k in state_dict):
+            converted_state_dict["proj_out.lora_B.weight"] = original_state_dict.pop(
+                f"head.head.{lora_up_key}.weight"
+            )
         if "head.head.diff_b" in original_state_dict:
             converted_state_dict["proj_out.lora_B.bias"] = original_state_dict.pop("head.head.diff_b")
 
+        # Notes: https://huggingface.co/lightx2v/Wan2.2-Distill-Loras
+        # This is my (sayakpaul) assumption that this particular key belongs to the down matrix.
+        # Since for this particular LoRA, we don't have the corresponding up matrix, I will use
+        # an identity.
+        if any("head.head" in k and k.endswith(".diff") for k in state_dict):
+            if f"head.head.{lora_down_key}.weight" in state_dict:
+                logger.info(
+                    f"The state dict seems to have both `head.head.diff` and `head.head.{lora_down_key}.weight` keys, which is unexpected."
+                )
+            converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop("head.head.diff")
+            down_matrix_head = converted_state_dict["proj_out.lora_A.weight"]
+            up_matrix_shape = (down_matrix_head.shape[0], converted_state_dict["proj_out.lora_B.bias"].shape[0])
+            converted_state_dict["proj_out.lora_B.weight"] = torch.eye(
+                *up_matrix_shape, dtype=down_matrix_head.dtype, device=down_matrix_head.device
+            ).T
+
     for text_time in ["text_embedding", "time_embedding"]:
         if any(text_time in k for k in original_state_dict):
             for b_n in [0, 2]:
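
The identity trick above works because a PEFT LoRA update is applied as `lora_B @ lora_A`, so pairing the lone `head.head.diff` matrix with an identity-shaped `lora_B` passes its values through unchanged. A small self-contained sketch with made-up shapes (rank 8, 16 output features), not taken from any actual checkpoint:

```python
import torch

down = torch.randn(8, 64)   # stands in for `head.head.diff`, used as lora_A
bias = torch.zeros(16)      # stands in for `proj_out.lora_B.bias`

# Same construction as in the converter: an identity of shape (rank, out_features), transposed.
up = torch.eye(down.shape[0], bias.shape[0]).T   # (16, 8)

delta = up @ down           # effective weight delta the LoRA produces, shape (16, 64)
# The first `rank` rows reproduce the down matrix exactly; the remaining rows are zero.
assert torch.allclose(delta[: down.shape[0]], down)
assert torch.all(delta[down.shape[0] :] == 0)
```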
@@ -1337,9 +1337,18 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
         tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
         tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
         tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
 
-        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
-        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+        tile_sample_stride_height = self.tile_sample_stride_height
+        tile_sample_stride_width = self.tile_sample_stride_width
+        if self.config.patch_size is not None:
+            sample_height = sample_height // self.config.patch_size
+            sample_width = sample_width // self.config.patch_size
+            tile_sample_stride_height = tile_sample_stride_height // self.config.patch_size
+            tile_sample_stride_width = tile_sample_stride_width // self.config.patch_size
+            blend_height = self.tile_sample_min_height // self.config.patch_size - tile_sample_stride_height
+            blend_width = self.tile_sample_min_width // self.config.patch_size - tile_sample_stride_width
+        else:
+            blend_height = self.tile_sample_min_height - tile_sample_stride_height
+            blend_width = self.tile_sample_min_width - tile_sample_stride_width
 
         # Split z into overlapping tiles and decode them separately.
         # The tiles have an overlap to avoid seams between tiles.
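
As a quick sanity check on the arithmetic above, a tiny standalone sketch with made-up tile sizes shows how the stride and blend values shrink once a `patch_size` is configured (real values come from the VAE config, not from here):

```python
tile_sample_min_height = 256       # illustrative numbers only
tile_sample_stride_height = 192
patch_size = 2                     # None for VAEs without patching

stride = tile_sample_stride_height
if patch_size is not None:
    stride = stride // patch_size                           # 192 // 2 = 96
    blend = tile_sample_min_height // patch_size - stride   # 128 - 96 = 32
else:
    blend = tile_sample_min_height - stride                 # 256 - 192 = 64

print(stride, blend)  # 96 32
```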
@@ -1353,7 +1362,9 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
                     self._conv_idx = [0]
                     tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
                     tile = self.post_quant_conv(tile)
-                    decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
+                    decoded = self.decoder(
+                        tile, feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=(k == 0)
+                    )
                     time.append(decoded)
                 row.append(torch.cat(time, dim=2))
             rows.append(row)
@@ -1369,11 +1380,15 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
                     tile = self.blend_v(rows[i - 1][j], tile, blend_height)
                 if j > 0:
                     tile = self.blend_h(row[j - 1], tile, blend_width)
-                result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
+                result_row.append(tile[:, :, :, :tile_sample_stride_height, :tile_sample_stride_width])
             result_rows.append(torch.cat(result_row, dim=-1))
 
         dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
 
+        if self.config.patch_size is not None:
+            dec = unpatchify(dec, patch_size=self.config.patch_size)
+
         dec = torch.clamp(dec, min=-1.0, max=1.0)
 
         if not return_dict:
             return (dec,)
         return DecoderOutput(sample=dec)
@@ -22,7 +22,7 @@ import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
 from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
@@ -717,7 +717,11 @@ class FluxTransformer2DModel(
             img_ids = img_ids[0]
 
         ids = torch.cat((txt_ids, img_ids), dim=0)
-        image_rotary_emb = self.pos_embed(ids)
+        if is_torch_npu_available():
+            freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
+            image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
+        else:
+            image_rotary_emb = self.pos_embed(ids)
 
         if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
             ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
@@ -324,6 +324,7 @@ class Kandinsky5AttnProcessor:
                 sparse_params["sta_mask"],
                 thr=sparse_params["P"],
             )
+
         else:
             attn_mask = None
 
@@ -335,6 +336,7 @@ class Kandinsky5AttnProcessor:
             backend=self._attention_backend,
             parallel_config=self._parallel_config,
         )
+
         hidden_states = hidden_states.flatten(-2, -1)
 
         attn_out = attn.out_layer(hidden_states)
@@ -266,7 +266,7 @@ class StableDiffusion3ControlNetPipeline(
             return torch.zeros(
                 (
                     batch_size * num_images_per_prompt,
-                    self.tokenizer_max_length,
+                    max_sequence_length,
                     self.transformer.config.joint_attention_dim,
                 ),
                 device=device,

@@ -284,7 +284,7 @@ class StableDiffusion3ControlNetInpaintingPipeline(
             return torch.zeros(
                 (
                     batch_size * num_images_per_prompt,
-                    self.tokenizer_max_length,
+                    max_sequence_length,
                     self.transformer.config.joint_attention_dim,
                 ),
                 device=device,
@@ -173,8 +173,10 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
         )
         self.prompt_template_encode_start_idx = 129
 
-        self.vae_scale_factor_temporal = vae.config.temporal_compression_ratio
-        self.vae_scale_factor_spatial = vae.config.spatial_compression_ratio
+        self.vae_scale_factor_temporal = (
+            self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
+        )
+        self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 8
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
 
     @staticmethod
@@ -384,6 +386,9 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
         device = device or self._execution_device
         dtype = dtype or self.text_encoder.dtype
 
+        if not isinstance(prompt, list):
+            prompt = [prompt]
+
         batch_size = len(prompt)
 
         prompt = [prompt_clean(p) for p in prompt]
@@ -237,7 +237,7 @@ class StableDiffusion3PAGPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSin
             return torch.zeros(
                 (
                     batch_size * num_images_per_prompt,
-                    self.tokenizer_max_length,
+                    max_sequence_length,
                     self.transformer.config.joint_attention_dim,
                 ),
                 device=device,

@@ -253,7 +253,7 @@ class StableDiffusion3PAGImg2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
             return torch.zeros(
                 (
                     batch_size * num_images_per_prompt,
-                    self.tokenizer_max_length,
+                    max_sequence_length,
                     self.transformer.config.joint_attention_dim,
                 ),
                 device=device,

@@ -248,7 +248,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
             return torch.zeros(
                 (
                     batch_size * num_images_per_prompt,
-                    self.tokenizer_max_length,
+                    max_sequence_length,
                     self.transformer.config.joint_attention_dim,
                 ),
                 device=device,

@@ -272,7 +272,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
             return torch.zeros(
                 (
                     batch_size * num_images_per_prompt,
-                    self.tokenizer_max_length,
+                    max_sequence_length,
                     self.transformer.config.joint_attention_dim,
                 ),
                 device=device,

@@ -278,7 +278,7 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
             return torch.zeros(
                 (
                     batch_size * num_images_per_prompt,
-                    self.tokenizer_max_length,
+                    max_sequence_length,
                     self.transformer.config.joint_attention_dim,
                 ),
                 device=device,
@@ -152,34 +152,36 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         text_encoder ([`T5EncoderModel`]):
             [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
             the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
-        transformer ([`WanVACETransformer3DModel`]):
-            Conditional Transformer to denoise the input latents.
-        transformer_2 ([`WanVACETransformer3DModel`], *optional*):
-            Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
-            `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided, only
-            `transformer` is used.
-        scheduler ([`UniPCMultistepScheduler`]):
-            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
         vae ([`AutoencoderKLWan`]):
             Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+        scheduler ([`UniPCMultistepScheduler`]):
+            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+        transformer ([`WanVACETransformer3DModel`], *optional*):
+            Conditional Transformer to denoise the input latents during the high-noise stage. In two-stage denoising,
+            `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. At least one of
+            `transformer` or `transformer_2` must be provided.
+        transformer_2 ([`WanVACETransformer3DModel`], *optional*):
+            Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
+            `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. At least one of
+            `transformer` or `transformer_2` must be provided.
         boundary_ratio (`float`, *optional*, defaults to `None`):
             Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
             The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
             `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
-            boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
+            boundary_timestep. If `None`, only the available transformer is used for the entire denoising process.
     """
 
-    model_cpu_offload_seq = "text_encoder->transformer->vae"
+    model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae"
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
-    _optional_components = ["transformer_2"]
+    _optional_components = ["transformer", "transformer_2"]
 
     def __init__(
         self,
        tokenizer: AutoTokenizer,
         text_encoder: UMT5EncoderModel,
-        transformer: WanVACETransformer3DModel,
         vae: AutoencoderKLWan,
         scheduler: FlowMatchEulerDiscreteScheduler,
+        transformer: WanVACETransformer3DModel = None,
         transformer_2: WanVACETransformer3DModel = None,
         boundary_ratio: Optional[float] = None,
     ):
@@ -336,7 +338,15 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         reference_images=None,
         guidance_scale_2=None,
     ):
-        base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
+        if self.transformer is not None:
+            base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
+        elif self.transformer_2 is not None:
+            base = self.vae_scale_factor_spatial * self.transformer_2.config.patch_size[1]
+        else:
+            raise ValueError(
+                "`transformer` or `transformer_2` component must be set in order to run inference with this pipeline"
+            )
+
         if height % base != 0 or width % base != 0:
             raise ValueError(f"`height` and `width` have to be divisible by {base} but are {height} and {width}.")
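
The same "use `transformer` if it is loaded, otherwise fall back to `transformer_2`" pattern recurs in the hunks below for the patch size, dtype, VACE layer count, and latent channel count. A hypothetical helper, not part of the pipeline, shown only to make the shared logic explicit:

```python
def _any_transformer(pipe):
    """Return whichever denoiser is available, mirroring the checks in this diff."""
    if getattr(pipe, "transformer", None) is not None:
        return pipe.transformer
    if getattr(pipe, "transformer_2", None) is not None:
        return pipe.transformer_2
    raise ValueError(
        "`transformer` or `transformer_2` component must be set in order to run inference with this pipeline"
    )

# e.g. base = pipe.vae_scale_factor_spatial * _any_transformer(pipe).config.patch_size[1]
```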
@@ -414,7 +424,11 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         device: Optional[torch.device] = None,
     ):
         if video is not None:
-            base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
+            base = self.vae_scale_factor_spatial * (
+                self.transformer.config.patch_size[1]
+                if self.transformer is not None
+                else self.transformer_2.config.patch_size[1]
+            )
             video_height, video_width = self.video_processor.get_default_height_width(video[0])
 
             if video_height * video_width > height * width:
@@ -589,7 +603,11 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                 "Generating with more than one video is not yet supported. This may be supported in the future."
             )
 
-        transformer_patch_size = self.transformer.config.patch_size[1]
+        transformer_patch_size = (
+            self.transformer.config.patch_size[1]
+            if self.transformer is not None
+            else self.transformer_2.config.patch_size[1]
+        )
 
         mask_list = []
         for mask_, reference_images_batch in zip(mask, reference_images):
@@ -844,20 +862,25 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         batch_size = prompt_embeds.shape[0]
 
         vae_dtype = self.vae.dtype
-        transformer_dtype = self.transformer.dtype
+        transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
 
+        vace_layers = (
+            self.transformer.config.vace_layers
+            if self.transformer is not None
+            else self.transformer_2.config.vace_layers
+        )
         if isinstance(conditioning_scale, (int, float)):
-            conditioning_scale = [conditioning_scale] * len(self.transformer.config.vace_layers)
+            conditioning_scale = [conditioning_scale] * len(vace_layers)
         if isinstance(conditioning_scale, list):
-            if len(conditioning_scale) != len(self.transformer.config.vace_layers):
+            if len(conditioning_scale) != len(vace_layers):
                 raise ValueError(
-                    f"Length of `conditioning_scale` {len(conditioning_scale)} does not match number of layers {len(self.transformer.config.vace_layers)}."
+                    f"Length of `conditioning_scale` {len(conditioning_scale)} does not match number of layers {len(vace_layers)}."
                 )
             conditioning_scale = torch.tensor(conditioning_scale)
         if isinstance(conditioning_scale, torch.Tensor):
-            if conditioning_scale.size(0) != len(self.transformer.config.vace_layers):
+            if conditioning_scale.size(0) != len(vace_layers):
                 raise ValueError(
-                    f"Length of `conditioning_scale` {conditioning_scale.size(0)} does not match number of layers {len(self.transformer.config.vace_layers)}."
+                    f"Length of `conditioning_scale` {conditioning_scale.size(0)} does not match number of layers {len(vace_layers)}."
                 )
             conditioning_scale = conditioning_scale.to(device=device, dtype=transformer_dtype)
@@ -900,7 +923,11 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         conditioning_latents = torch.cat([conditioning_latents, mask], dim=1)
         conditioning_latents = conditioning_latents.to(transformer_dtype)
 
-        num_channels_latents = self.transformer.config.in_channels
+        num_channels_latents = (
+            self.transformer.config.in_channels
+            if self.transformer is not None
+            else self.transformer_2.config.in_channels
+        )
         latents = self.prepare_latents(
             batch_size * num_videos_per_prompt,
             num_channels_latents,
@@ -968,7 +995,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
-                    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+                    noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
 
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
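
`current_guidance_scale` feeds the usual classifier-free guidance update, but with a scale that can differ between the high-noise and low-noise stages. A sketch of how such a per-timestep scale is typically selected in the two-stage setup (names and numbers here are illustrative, not the pipeline's exact code):

```python
num_train_timesteps = 1000
boundary_ratio = 0.9                     # assumed config value
boundary_timestep = boundary_ratio * num_train_timesteps

def pick_guidance_scale(t, guidance_scale=5.0, guidance_scale_2=3.0):
    # High-noise steps (t >= boundary) use `guidance_scale`; low-noise steps use `guidance_scale_2`.
    return guidance_scale if t >= boundary_timestep else guidance_scale_2

# noise_pred = noise_uncond + pick_guidance_scale(t) * (noise_pred - noise_uncond)
```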
tests/pipelines/kandinsky5/__init__.py (new file, 0 lines)

tests/pipelines/kandinsky5/test_kandinsky5.py (new file, 306 lines)
@@ -0,0 +1,306 @@
# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
from transformers import (
    CLIPTextConfig,
    CLIPTextModel,
    CLIPTokenizer,
    Qwen2_5_VLConfig,
    Qwen2_5_VLForConditionalGeneration,
    Qwen2VLProcessor,
)

from diffusers import (
    AutoencoderKLHunyuanVideo,
    FlowMatchEulerDiscreteScheduler,
    Kandinsky5T2VPipeline,
    Kandinsky5Transformer3DModel,
)

from ...testing_utils import (
    enable_full_determinism,
    torch_device,
)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin


enable_full_determinism()

class Kandinsky5T2VPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = Kandinsky5T2VPipeline
    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "prompt_embeds", "negative_prompt_embeds"}
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS

    # Define required optional parameters for your pipeline
    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "generator",
            "latents",
            "return_dict",
            "callback_on_step_end",
            "callback_on_step_end_tensor_inputs",
            "max_sequence_length",
        ]
    )

    test_xformers_attention = False
    supports_dduf = False

    def get_dummy_components(self):
        torch.manual_seed(0)
        vae = AutoencoderKLHunyuanVideo(
            in_channels=3,
            out_channels=3,
            spatial_compression_ratio=8,
            temporal_compression_ratio=4,
            latent_channels=4,
            block_out_channels=(8, 8, 8, 8),
            layers_per_block=1,
            norm_num_groups=4,
        )

        torch.manual_seed(0)
        scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)

        # Dummy Qwen2.5-VL model
        config = Qwen2_5_VLConfig(
            text_config={
                "hidden_size": 16,
                "intermediate_size": 16,
                "num_hidden_layers": 2,
                "num_attention_heads": 2,
                "num_key_value_heads": 2,
                "rope_scaling": {
                    "mrope_section": [1, 1, 2],
                    "rope_type": "default",
                    "type": "default",
                },
                "rope_theta": 1000000.0,
            },
            vision_config={
                "depth": 2,
                "hidden_size": 16,
                "intermediate_size": 16,
                "num_heads": 2,
                "out_hidden_size": 16,
            },
            hidden_size=16,
            vocab_size=152064,
            vision_end_token_id=151653,
            vision_start_token_id=151652,
            vision_token_id=151654,
        )
        text_encoder = Qwen2_5_VLForConditionalGeneration(config)
        tokenizer = Qwen2VLProcessor.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")

        # Dummy CLIP model
        clip_text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
            hidden_act="gelu",
            projection_dim=32,
        )

        torch.manual_seed(0)
        text_encoder_2 = CLIPTextModel(clip_text_encoder_config)
        tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        torch.manual_seed(0)
        transformer = Kandinsky5Transformer3DModel(
            in_visual_dim=4,
            in_text_dim=16,  # Match tiny Qwen2.5-VL hidden size
            in_text_dim2=32,  # Match tiny CLIP hidden size
            time_dim=32,
            out_visual_dim=4,
            patch_size=(1, 2, 2),
            model_dim=48,
            ff_dim=128,
            num_text_blocks=1,
            num_visual_blocks=1,
            axes_dims=(8, 8, 8),
            visual_cond=False,
        )

        components = {
            "transformer": transformer.eval(),
            "vae": vae.eval(),
            "scheduler": scheduler,
            "text_encoder": text_encoder.eval(),
            "tokenizer": tokenizer,
            "text_encoder_2": text_encoder_2.eval(),
            "tokenizer_2": tokenizer_2,
        }
        return components
    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "prompt": "A cat dancing",
            "negative_prompt": "blurry, low quality",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 5.0,
            "height": 32,
            "width": 32,
            "num_frames": 5,
            "max_sequence_length": 16,
            "output_type": "pt",
        }
        return inputs

    def test_inference(self):
        device = "cpu"

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        video = pipe(**inputs).frames

        # Check video shape: (batch, frames, channel, height, width)
        expected_shape = (1, 5, 3, 32, 32)
        self.assertEqual(video.shape, expected_shape)

        # Check specific values
        expected_slice = torch.tensor(
            [0.4330, 0.4254, 0.4285, 0.3835, 0.4253, 0.4196, 0.3704, 0.3714,
             0.4999, 0.5346, 0.4795, 0.4637, 0.4930, 0.5124, 0.4902, 0.4570]
        )

        generated_slice = video.flatten()
        # Take first 8 and last 8 values for comparison
        video_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
        self.assertTrue(
            torch.allclose(video_slice, expected_slice, atol=1e-3),
            f"video_slice: {video_slice}, expected_slice: {expected_slice}",
        )
    def test_inference_batch_single_identical(self):
        # Override to test batch single identical with video
        super().test_inference_batch_single_identical(batch_size=2, expected_max_diff=1e-2)

    def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=1e-3, rtol=1e-3):
        components = self.get_dummy_components()

        text_component_names = ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"]
        text_components = {k: (v if k in text_component_names else None) for k, v in components.items()}
        non_text_components = {k: (v if k not in text_component_names else None) for k, v in components.items()}

        pipe_with_just_text_encoder = self.pipeline_class(**text_components)
        pipe_with_just_text_encoder = pipe_with_just_text_encoder.to(torch_device)

        pipe_without_text_encoders = self.pipeline_class(**non_text_components)
        pipe_without_text_encoders = pipe_without_text_encoders.to(torch_device)

        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)

        # Compute `encode_prompt()`.

        # Test single prompt
        prompt = "A cat dancing"
        with torch.no_grad():
            prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = pipe_with_just_text_encoder.encode_prompt(
                prompt, device=torch_device, max_sequence_length=16
            )

        # Check shapes
        self.assertEqual(prompt_embeds_qwen.shape, (1, 4, 16))  # [batch, seq_len, embed_dim]
        self.assertEqual(prompt_embeds_clip.shape, (1, 32))  # [batch, embed_dim]
        self.assertEqual(prompt_cu_seqlens.shape, (2,))  # [batch + 1]

        # Test batch of prompts
        prompts = ["A cat dancing", "A dog running"]
        with torch.no_grad():
            batch_embeds_qwen, batch_embeds_clip, batch_cu_seqlens = pipe_with_just_text_encoder.encode_prompt(
                prompts, device=torch_device, max_sequence_length=16
            )

        # Check batch size
        self.assertEqual(batch_embeds_qwen.shape, (len(prompts), 4, 16))
        self.assertEqual(batch_embeds_clip.shape, (len(prompts), 32))
        self.assertEqual(len(batch_cu_seqlens), len(prompts) + 1)  # [0, len1, len1+len2]

        inputs = self.get_dummy_inputs(torch_device)
        inputs["guidance_scale"] = 1.0

        # baseline output: full pipeline
        pipe_out = pipe(**inputs).frames

        # test against pipeline call with pre-computed prompt embeds
        inputs = self.get_dummy_inputs(torch_device)
        inputs["guidance_scale"] = 1.0

        with torch.no_grad():
            prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = pipe_with_just_text_encoder.encode_prompt(
                inputs["prompt"], device=torch_device, max_sequence_length=inputs["max_sequence_length"]
            )

        inputs["prompt"] = None
        inputs["prompt_embeds_qwen"] = prompt_embeds_qwen
        inputs["prompt_embeds_clip"] = prompt_embeds_clip
        inputs["prompt_cu_seqlens"] = prompt_cu_seqlens

        pipe_out_2 = pipe_without_text_encoders(**inputs)[0]

        self.assertTrue(
            torch.allclose(pipe_out, pipe_out_2, atol=atol, rtol=rtol),
            f"max diff: {torch.max(torch.abs(pipe_out - pipe_out_2))}",
        )

    @unittest.skip("Kandinsky5T2VPipeline does not support attention slicing")
    def test_attention_slicing_forward_pass(self):
        pass

    @unittest.skip("Kandinsky5T2VPipeline does not support xformers")
    def test_xformers_attention_forwardGenerator_pass(self):
        pass

    @unittest.skip("Kandinsky5T2VPipeline does not support VAE slicing")
    def test_vae_slicing(self):
        pass
@@ -1461,6 +1461,8 @@ class PipelineTesterMixin:
     def test_save_load_optional_components(self, expected_max_difference=1e-4):
         if not hasattr(self.pipeline_class, "_optional_components"):
             return
+        if not self.pipeline_class._optional_components:
+            return
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         for component in pipe.components.values():
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import tempfile
 import unittest
 
 import numpy as np
@@ -19,9 +20,15 @@ import torch
 from PIL import Image
 from transformers import AutoTokenizer, T5EncoderModel
 
-from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanVACEPipeline, WanVACETransformer3DModel
+from diffusers import (
+    AutoencoderKLWan,
+    FlowMatchEulerDiscreteScheduler,
+    UniPCMultistepScheduler,
+    WanVACEPipeline,
+    WanVACETransformer3DModel,
+)
 
-from ...testing_utils import enable_full_determinism
+from ...testing_utils import enable_full_determinism, torch_device
 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
 from ..test_pipelines_common import PipelineTesterMixin
@@ -212,3 +219,81 @@ class WanVACEPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     )
     def test_save_load_float16(self):
         pass
+
+    def test_inference_with_only_transformer(self):
+        components = self.get_dummy_components()
+        components["transformer_2"] = None
+        components["boundary_ratio"] = 0.0
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+        video = pipe(**inputs).frames[0]
+        assert video.shape == (17, 3, 16, 16)
+
+    def test_inference_with_only_transformer_2(self):
+        components = self.get_dummy_components()
+        components["transformer_2"] = components["transformer"]
+        components["transformer"] = None
+
+        # FlowMatchEulerDiscreteScheduler doesn't support running a low-noise-only schedule
+        # because the starting timestep t == 1000 == boundary_timestep
+        components["scheduler"] = UniPCMultistepScheduler(
+            prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0
+        )
+
+        components["boundary_ratio"] = 1.0
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(torch_device)
+        video = pipe(**inputs).frames[0]
+        assert video.shape == (17, 3, 16, 16)
+
+    def test_save_load_optional_components(self, expected_max_difference=1e-4):
+        optional_component = ["transformer"]
+
+        components = self.get_dummy_components()
+        components["transformer_2"] = components["transformer"]
+        # FlowMatchEulerDiscreteScheduler doesn't support running a low-noise-only schedule
+        # because the starting timestep t == 1000 == boundary_timestep
+        components["scheduler"] = UniPCMultistepScheduler(
+            prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0
+        )
+        for component in optional_component:
+            components[component] = None
+
+        components["boundary_ratio"] = 1.0
+
+        pipe = self.pipeline_class(**components)
+        for component in pipe.components.values():
+            if hasattr(component, "set_default_attn_processor"):
+                component.set_default_attn_processor()
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        generator_device = "cpu"
+        inputs = self.get_dummy_inputs(generator_device)
+        torch.manual_seed(0)
+        output = pipe(**inputs)[0]
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            pipe.save_pretrained(tmpdir, safe_serialization=False)
+            pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+            for component in pipe_loaded.components.values():
+                if hasattr(component, "set_default_attn_processor"):
+                    component.set_default_attn_processor()
+            pipe_loaded.to(torch_device)
+            pipe_loaded.set_progress_bar_config(disable=None)
+
+        for component in optional_component:
+            assert getattr(pipe_loaded, component) is None, f"`{component}` did not stay set to None after loading."
+
+        inputs = self.get_dummy_inputs(generator_device)
+        torch.manual_seed(0)
+        output_loaded = pipe_loaded(**inputs)[0]
+
+        max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
+        assert max_diff < expected_max_difference, "Outputs exceed expected maximum difference"