1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Propose to update & upgrade SkyReels-V2 (#12167)

* fix: update SkyReels-V2 documentation and moving into attn dispatcher

* Refactors SkyReelsV2's attention implementation

* style

* up

* Fixes formatting in SkyReels-V2 documentation

Wraps the visual demonstration section in a Markdown code block.

This change corrects the rendering of ASCII diagrams and examples, improving the overall readability of the document.

* Docs: Condense example arrays in skyreels_v2 guide

Improves the readability of the `step_matrix` examples by replacing long sequences of repeated numbers with a more compact `valueΓ—count` notation.

This change makes the underlying data patterns in the examples easier to understand at a glance.

* Add _repeated_blocks attribute to SkyReelsV2Transformer3DModel

* Refactor rotary embedding calculations in SkyReelsV2 to separate cosine and sine frequencies

* Enhance SkyReels-V2 documentation: update model loading for GPU support and remove outdated notes

* up

* up

* Update model_id in SkyReels-V2 documentation

* up

* refactor: remove device_map parameter for model loading and add pipeline.to("cuda") for GPU allocation

* fix: update copyright year to 2025 in skyreels_v2.md

* docs: enhance parameter examples and formatting in skyreels_v2.md

* docs: update example formatting and add notes on LoRA support in skyreels_v2.md

* refactor: remove copied comments from transformer_wan in SkyReelsV2 classes

* Clean up comments in skyreels_v2.md

Removed comments about acceleration helpers and Flash Attention installation.

* Add deprecation warning for `SkyReelsV2AttnProcessor2_0` class
This commit is contained in:
Tolga CangΓΆz
2025-08-26 10:24:19 +03:00
committed by GitHub
parent 0fd7ee79ea
commit 5fcd5f560f
2 changed files with 365 additions and 212 deletions

View File

@@ -1,4 +1,4 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,7 +22,7 @@
# SkyReels-V2: Infinite-length Film Generative model
[SkyReels-V2](https://huggingface.co/papers/2504.13074) by the SkyReels Team.
[SkyReels-V2](https://huggingface.co/papers/2504.13074) by the SkyReels Team from Skywork AI.
*Recent advances in video generation have been driven by diffusion models and autoregressive frameworks, yet critical challenges persist in harmonizing prompt adherence, visual quality, motion dynamics, and duration: compromises in motion dynamics to enhance temporal visual quality, constrained video duration (5-10 seconds) to prioritize resolution, and inadequate shot-aware generation stemming from general-purpose MLLMs' inability to interpret cinematic grammar, such as shot composition, actor expressions, and camera motions. These intertwined limitations hinder realistic long-form synthesis and professional film-style generation. To address these limitations, we propose SkyReels-V2, an Infinite-length Film Generative Model, that synergizes Multi-modal Large Language Model (MLLM), Multi-stage Pretraining, Reinforcement Learning, and Diffusion Forcing Framework. Firstly, we design a comprehensive structural representation of video that combines the general descriptions by the Multi-modal LLM and the detailed shot language by sub-expert models. Aided with human annotation, we then train a unified Video Captioner, named SkyCaptioner-V1, to efficiently label the video data. Secondly, we establish progressive-resolution pretraining for the fundamental video generation, followed by a four-stage post-training enhancement: Initial concept-balanced Supervised Fine-Tuning (SFT) improves baseline quality; Motion-specific Reinforcement Learning (RL) training with human-annotated and synthetic distortion data addresses dynamic artifacts; Our diffusion forcing framework with non-decreasing noise schedules enables long-video synthesis in an efficient search space; Final high-quality SFT refines visual fidelity. All the code and models are available at [this https URL](https://github.com/SkyworkAI/SkyReels-V2).*
@@ -44,93 +44,113 @@ The following SkyReels-V2 models are supported in Diffusers:
### A _Visual_ Demonstration
An example with these parameters:
base_num_frames=97, num_frames=97, num_inference_steps=30, ar_step=5, causal_block_size=5
The example below has the following parameters:
vae_scale_factor_temporal -> 4
num_latent_frames: (97-1)//vae_scale_factor_temporal+1 = 25 frames -> 5 blocks of 5 frames each
- `base_num_frames=97`
- `num_frames=97`
- `num_inference_steps=30`
- `ar_step=5`
- `causal_block_size=5`
base_num_latent_frames = (97-1)//vae_scale_factor_temporal+1 = 25 β†’ blocks = 25//5 = 5 blocks
This 5 blocks means the maximum context length of the model is 25 frames in the latent space.
With `vae_scale_factor_temporal=4`, expect `5` blocks of `5` frames each as calculated by:
Asynchronous Processing Timeline:
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ Steps: 1 6 11 16 21 26 31 36 41 46 50 β”‚
β”‚ Block 1: [β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– ] β”‚
β”‚ Block 2: [β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– ] β”‚
β”‚ Block 3: [β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– ] β”‚
β”‚ Block 4: [β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– ] β”‚
β”‚ Block 5: [β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– ] β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
`num_latent_frames: (97-1)//vae_scale_factor_temporal+1 = 25 frames -> 5 blocks of 5 frames each`
For Long Videos (num_frames > base_num_frames):
base_num_frames acts as the "sliding window size" for processing long videos.
And the maximum context length in the latent space is calculated with `base_num_latent_frames`:
Example: 257-frame video with base_num_frames=97, overlap_history=17
β”Œβ”€β”€β”€β”€ Iteration 1 (frames 1-97) ────┐
β”‚ Processing window: 97 frames β”‚ β†’ 5 blocks, async processing
β”‚ Generates: frames 1-97 β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
β”Œβ”€β”€β”€β”€β”€β”€ Iteration 2 (frames 81-177) ──────┐
β”‚ Processing window: 97 frames β”‚
β”‚ Overlap: 17 frames (81-97) from prev β”‚ β†’ 5 blocks, async processing
β”‚ Generates: frames 98-177 β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
β”Œβ”€β”€β”€β”€β”€β”€ Iteration 3 (frames 161-257) ──────┐
β”‚ Processing window: 97 frames β”‚
β”‚ Overlap: 17 frames (161-177) from prev β”‚ β†’ 5 blocks, async processing
β”‚ Generates: frames 178-257 β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
`base_num_latent_frames = (97-1)//vae_scale_factor_temporal+1 = 25 -> 25//5 = 5 blocks`
Each iteration independently runs the asynchronous processing with its own 5 blocks.
base_num_frames controls:
1. Memory usage (larger window = more VRAM)
2. Model context length (must match training constraints)
3. Number of blocks per iteration (base_num_latent_frames // causal_block_size)
Asynchronous Processing Timeline:
```text
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ Steps: 1 6 11 16 21 26 31 36 41 46 50 β”‚
β”‚ Block 1: [β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– ] β”‚
β”‚ Block 2: [β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– ] β”‚
β”‚ Block 3: [β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– ] β”‚
β”‚ Block 4: [β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– ] β”‚
β”‚ Block 5: [β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– ] β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
```
Each block takes 30 steps to complete denoising.
Block N starts at step: 1 + (N-1) x ar_step
Total steps: 30 + (5-1) x 5 = 50 steps
For Long Videos (`num_frames` > `base_num_frames`):
`base_num_frames` acts as the "sliding window size" for processing long videos.
Example: `257`-frame video with `base_num_frames=97`, `overlap_history=17`
```text
β”Œβ”€β”€β”€β”€ Iteration 1 (frames 1-97) ────┐
β”‚ Processing window: 97 frames β”‚ β†’ 5 blocks,
β”‚ Generates: frames 1-97 β”‚ async processing
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
β”Œβ”€β”€β”€β”€β”€β”€ Iteration 2 (frames 81-177) ──────┐
β”‚ Processing window: 97 frames β”‚
β”‚ Overlap: 17 frames (81-97) from prev β”‚ β†’ 5 blocks,
β”‚ Generates: frames 98-177 β”‚ async processing
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
β”Œβ”€β”€β”€β”€β”€β”€ Iteration 3 (frames 161-257) ──────┐
β”‚ Processing window: 97 frames β”‚
β”‚ Overlap: 17 frames (161-177) from prev β”‚ β†’ 5 blocks,
β”‚ Generates: frames 178-257 β”‚ async processing
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
```
Each iteration independently runs the asynchronous processing with its own `5` blocks.
`base_num_frames` controls:
1. Memory usage (larger window = more VRAM)
2. Model context length (must match training constraints)
3. Number of blocks per iteration (`base_num_latent_frames // causal_block_size`)
Each block takes `30` steps to complete denoising.
Block N starts at step: `1 + (N-1) x ar_step`
Total steps: `30 + (5-1) x 5 = 50` steps
Synchronous mode (ar_step=0) would process all blocks/frames simultaneously:
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ Steps: 1 ... 30 β”‚
β”‚ All blocks: [β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– ] β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
Total steps: 30 steps
Synchronous mode (`ar_step=0`) would process all blocks/frames simultaneously:
```text
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ Steps: 1 ... 30 β”‚
β”‚ All blocks: [β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– β– ] β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
```
Total steps: `30` steps
An example on how the step matrix is constructed for asynchronous processing:
Given the parameters: (num_inference_steps=30, flow_shift=8, num_frames=97, ar_step=5, causal_block_size=5)
- num_latent_frames = (97 frames - 1) // (4 temporal downsampling) + 1 = 25
- step_template = [999, 995, 991, 986, 980, 975, 969, 963, 956, 948,
941, 932, 922, 912, 901, 888, 874, 859, 841, 822,
799, 773, 743, 708, 666, 615, 551, 470, 363, 216]
An example on how the step matrix is constructed for asynchronous processing:
Given the parameters: (`num_inference_steps=30, flow_shift=8, num_frames=97, ar_step=5, causal_block_size=5`)
```
- num_latent_frames = (97 frames - 1) // (4 temporal downsampling) + 1 = 25
- step_template = [999, 995, 991, 986, 980, 975, 969, 963, 956, 948,
941, 932, 922, 912, 901, 888, 874, 859, 841, 822,
799, 773, 743, 708, 666, 615, 551, 470, 363, 216]
```
The algorithm creates a 50x25 step_matrix where:
- Row 1: [999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999]
- Row 2: [995, 995, 995, 995, 995, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999]
- Row 3: [991, 991, 991, 991, 991, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999]
- ...
- Row 7: [969, 969, 969, 969, 969, 995, 995, 995, 995, 995, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999]
- ...
- Row 21: [799, 799, 799, 799, 799, 888, 888, 888, 888, 888, 941, 941, 941, 941, 941, 975, 975, 975, 975, 975, 999, 999, 999, 999, 999]
- ...
- Row 35: [ 0, 0, 0, 0, 0, 216, 216, 216, 216, 216, 666, 666, 666, 666, 666, 822, 822, 822, 822, 822, 901, 901, 901, 901, 901]
- ...
- Row 42: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 551, 551, 551, 551, 551, 773, 773, 773, 773, 773]
- ...
- Row 50: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 216, 216, 216, 216, 216]
The algorithm creates a `50x25` `step_matrix` where:
```
- Row 1: [999Γ—5, 999Γ—5, 999Γ—5, 999Γ—5, 999Γ—5]
- Row 2: [995Γ—5, 999Γ—5, 999Γ—5, 999Γ—5, 999Γ—5]
- Row 3: [991Γ—5, 999Γ—5, 999Γ—5, 999Γ—5, 999Γ—5]
- ...
- Row 7: [969Γ—5, 995Γ—5, 999Γ—5, 999Γ—5, 999Γ—5]
- ...
- Row 21: [799Γ—5, 888Γ—5, 941Γ—5, 975Γ—5, 999Γ—5]
- ...
- Row 35: [ 0Γ—5, 216Γ—5, 666Γ—5, 822Γ—5, 901Γ—5]
- ...
- Row 42: [ 0Γ—5, 0Γ—5, 0Γ—5, 551Γ—5, 773Γ—5]
- ...
- Row 50: [ 0Γ—5, 0Γ—5, 0Γ—5, 0Γ—5, 216Γ—5]
```
Detailed Row 6 Analysis:
- step_matrix[5]: [ 975, 975, 975, 975, 975, 999, 999, 999, 999, 999, 999, ..., 999]
- step_index[5]: [ 6, 6, 6, 6, 6, 1, 1, 1, 1, 1, 0, ..., 0]
- step_update_mask[5]: [True,True,True,True,True,True,True,True,True,True,False, ...,False]
- valid_interval[5]: (0, 25)
Detailed Row `6` Analysis:
```
- step_matrix[5]: [ 975Γ—5, 999Γ—5, 999Γ—5, 999Γ—5, 999Γ—5]
- step_index[5]: [ 6Γ—5, 1Γ—5, 0Γ—5, 0Γ—5, 0Γ—5]
- step_update_mask[5]: [TrueΓ—5, TrueΓ—5, FalseΓ—5, FalseΓ—5, FalseΓ—5]
- valid_interval[5]: (0, 25)
```
Key Pattern: Block `i` lags behind Block `i-1` by exactly `ar_step=5` timesteps, creating the
staggered "diffusion forcing" effect where later blocks condition on cleaner earlier blocks.
Key Pattern: Block i lags behind Block i-1 by exactly ar_step=5 timesteps, creating the
staggered "diffusion forcing" effect where later blocks condition on cleaner earlier blocks.
### Text-to-Video Generation
@@ -145,23 +165,22 @@ From the original repo:
>You can use --ar_step 5 to enable asynchronous inference. When asynchronous inference, --causal_block_size 5 is recommended while it is not supposed to be set for synchronous generation... Asynchronous inference will take more steps to diffuse the whole sequence which means it will be SLOWER than synchronous mode. In our experiments, asynchronous inference may improve the instruction following and visual consistent performance.
```py
# pip install ftfy
import torch
from diffusers import AutoModel, SkyReelsV2DiffusionForcingPipeline, UniPCMultistepScheduler
from diffusers.utils import export_to_video
vae = AutoModel.from_pretrained("Skywork/SkyReels-V2-DF-14B-540P-Diffusers", subfolder="vae", torch_dtype=torch.float32)
transformer = AutoModel.from_pretrained("Skywork/SkyReels-V2-DF-14B-540P-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
model_id = "Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers"
vae = AutoModel.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipeline = SkyReelsV2DiffusionForcingPipeline.from_pretrained(
"Skywork/SkyReels-V2-DF-14B-540P-Diffusers",
model_id,
vae=vae,
transformer=transformer,
torch_dtype=torch.bfloat16
torch_dtype=torch.bfloat16,
)
pipeline.to("cuda")
flow_shift = 8.0 # 8.0 for T2V, 5.0 for I2V
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config, flow_shift=flow_shift)
pipeline = pipeline.to("cuda")
prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
@@ -177,7 +196,7 @@ output = pipeline(
overlap_history=None, # Number of frames to overlap for smooth transitions in long videos; 17 for long video generations
addnoise_condition=20, # Improves consistency in long video generation
).frames[0]
export_to_video(output, "T2V.mp4", fps=24, quality=8)
export_to_video(output, "video.mp4", fps=24, quality=8)
```
</hfoption>
@@ -198,14 +217,14 @@ from diffusers import AutoencoderKLWan, SkyReelsV2DiffusionForcingImageToVideoPi
from diffusers.utils import export_to_video, load_image
model_id = "Skywork/SkyReels-V2-DF-14B-720P-Diffusers"
model_id = "Skywork/SkyReels-V2-DF-1.3B-720P-Diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipeline = SkyReelsV2DiffusionForcingImageToVideoPipeline.from_pretrained(
model_id, vae=vae, torch_dtype=torch.bfloat16
)
pipeline.to("cuda")
flow_shift = 5.0 # 8.0 for T2V, 5.0 for I2V
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config, flow_shift=flow_shift)
pipeline.to("cuda")
first_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png")
last_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png")
@@ -239,7 +258,7 @@ prompt = "CG animation style, a small blue bird takes off from the ground, flapp
output = pipeline(
image=first_frame, last_image=last_frame, prompt=prompt, height=height, width=width, guidance_scale=5.0
).frames[0]
export_to_video(output, "output.mp4", fps=24, quality=8)
export_to_video(output, "video.mp4", fps=24, quality=8)
```
</hfoption>
@@ -261,75 +280,35 @@ from diffusers import AutoencoderKLWan, SkyReelsV2DiffusionForcingVideoToVideoPi
from diffusers.utils import export_to_video, load_video
model_id = "Skywork/SkyReels-V2-DF-14B-540P-Diffusers"
model_id = "Skywork/SkyReels-V2-DF-1.3B-720P-Diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipeline = SkyReelsV2DiffusionForcingVideoToVideoPipeline.from_pretrained(
model_id, vae=vae, torch_dtype=torch.bfloat16
)
pipeline.to("cuda")
flow_shift = 5.0 # 8.0 for T2V, 5.0 for I2V
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config, flow_shift=flow_shift)
pipeline.to("cuda")
video = load_video("input_video.mp4")
prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
output = pipeline(
video=video, prompt=prompt, height=544, width=960, guidance_scale=5.0,
num_inference_steps=30, num_frames=257, base_num_frames=97#, ar_step=5, causal_block_size=5,
video=video, prompt=prompt, height=720, width=1280, guidance_scale=5.0, overlap_history=17,
num_inference_steps=30, num_frames=257, base_num_frames=121#, ar_step=5, causal_block_size=5,
).frames[0]
export_to_video(output, "output.mp4", fps=24, quality=8)
# Total frames will be the number of frames of given video + 257
export_to_video(output, "video.mp4", fps=24, quality=8)
# Total frames will be the number of frames of the given video + 257
```
</hfoption>
</hfoptions>
## Notes
- SkyReels-V2 supports LoRAs with [`~loaders.SkyReelsV2LoraLoaderMixin.load_lora_weights`].
<details>
<summary>Show example code</summary>
```py
# pip install ftfy
import torch
from diffusers import AutoModel, SkyReelsV2DiffusionForcingPipeline
from diffusers.utils import export_to_video
vae = AutoModel.from_pretrained(
"Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers", subfolder="vae", torch_dtype=torch.float32
)
pipeline = SkyReelsV2DiffusionForcingPipeline.from_pretrained(
"Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers", vae=vae, torch_dtype=torch.bfloat16
)
pipeline.to("cuda")
pipeline.load_lora_weights("benjamin-paine/steamboat-willie-1.3b", adapter_name="steamboat-willie")
pipeline.set_adapters("steamboat-willie")
pipeline.enable_model_cpu_offload()
# use "steamboat willie style" to trigger the LoRA
prompt = """
steamboat willie style, golden era animation, The camera rushes from far to near in a low-angle shot,
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
"""
output = pipeline(
prompt=prompt,
num_frames=97,
guidance_scale=6.0,
).frames[0]
export_to_video(output, "output.mp4", fps=24)
```
</details>
`SkyReelsV2Pipeline` and `SkyReelsV2ImageToVideoPipeline` are also available without Diffusion Forcing framework applied.
## SkyReelsV2DiffusionForcingPipeline
@@ -364,4 +343,4 @@ export_to_video(output, "output.mp4", fps=24, quality=8)
## SkyReelsV2PipelineOutput
[[autodoc]] pipelines.skyreels_v2.pipeline_output.SkyReelsV2PipelineOutput
[[autodoc]] pipelines.skyreels_v2.pipeline_output.SkyReelsV2PipelineOutput

View File

@@ -1,4 +1,4 @@
# Copyright 2025 The SkyReels-V2 Team, The Wan Team and The HuggingFace Team. All rights reserved.
# Copyright 2025 The SkyReels Team, The Wan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,9 +21,10 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention import FeedForward
from ..attention_processor import Attention
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
from ..embeddings import (
PixArtAlphaTextProjection,
@@ -39,20 +40,53 @@ from ..normalization import FP32LayerNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class SkyReelsV2AttnProcessor2_0:
def _get_qkv_projections(
attn: "SkyReelsV2Attention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor
):
# encoder_hidden_states is only passed for cross-attention
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
if attn.fused_projections:
if attn.cross_attention_dim_head is None:
# In self-attention layers, we can fuse the entire QKV projection into a single linear
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
else:
# In cross-attention layers, we can only fuse the KV projections into a single linear
query = attn.to_q(hidden_states)
key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1)
else:
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
return query, key, value
def _get_added_kv_projections(attn: "SkyReelsV2Attention", encoder_hidden_states_img: torch.Tensor):
if attn.fused_projections:
key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1)
else:
key_img = attn.add_k_proj(encoder_hidden_states_img)
value_img = attn.add_v_proj(encoder_hidden_states_img)
return key_img, value_img
class SkyReelsV2AttnProcessor:
_attention_backend = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"SkyReelsV2AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
"SkyReelsV2AttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
)
def __call__(
self,
attn: Attention,
attn: "SkyReelsV2Attention",
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
rotary_emb: Optional[torch.Tensor] = None,
rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
encoder_hidden_states_img = None
if attn.add_k_proj is not None:
@@ -60,58 +94,66 @@ class SkyReelsV2AttnProcessor2_0:
image_context_length = encoder_hidden_states.shape[1] - 512
encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
query = attn.norm_q(query)
key = attn.norm_k(key)
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
query = query.unflatten(2, (attn.heads, -1))
key = key.unflatten(2, (attn.heads, -1))
value = value.unflatten(2, (attn.heads, -1))
if rotary_emb is not None:
def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
x_rotated = torch.view_as_complex(hidden_states.to(torch.float32).unflatten(3, (-1, 2)))
x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
return x_out.type_as(hidden_states)
def apply_rotary_emb(
hidden_states: torch.Tensor,
freqs_cos: torch.Tensor,
freqs_sin: torch.Tensor,
):
x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
cos = freqs_cos[..., 0::2]
sin = freqs_sin[..., 1::2]
out = torch.empty_like(hidden_states)
out[..., 0::2] = x1 * cos - x2 * sin
out[..., 1::2] = x1 * sin + x2 * cos
return out.type_as(hidden_states)
query = apply_rotary_emb(query, rotary_emb)
key = apply_rotary_emb(key, rotary_emb)
query = apply_rotary_emb(query, *rotary_emb)
key = apply_rotary_emb(key, *rotary_emb)
# I2V task
hidden_states_img = None
if encoder_hidden_states_img is not None:
key_img = attn.add_k_proj(encoder_hidden_states_img)
key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img)
key_img = attn.norm_added_k(key_img)
value_img = attn.add_v_proj(encoder_hidden_states_img)
key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
key_img = key_img.unflatten(2, (attn.heads, -1))
value_img = value_img.unflatten(2, (attn.heads, -1))
hidden_states_img = F.scaled_dot_product_attention(
query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
hidden_states_img = dispatch_attention_fn(
query,
key_img,
value_img,
attn_mask=None,
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
)
hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
hidden_states_img = hidden_states_img.flatten(2, 3)
hidden_states_img = hidden_states_img.type_as(query)
hidden_states = F.scaled_dot_product_attention(
hidden_states = dispatch_attention_fn(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.type_as(query)
if hidden_states_img is not None:
@@ -122,7 +164,122 @@ class SkyReelsV2AttnProcessor2_0:
return hidden_states
# Copied from diffusers.models.transformers.transformer_wan.WanImageEmbedding with WanImageEmbedding -> SkyReelsV2ImageEmbedding
class SkyReelsV2AttnProcessor2_0:
def __new__(cls, *args, **kwargs):
deprecation_message = (
"The SkyReelsV2AttnProcessor2_0 class is deprecated and will be removed in a future version. "
"Please use SkyReelsV2AttnProcessor instead. "
)
deprecate("SkyReelsV2AttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False)
return SkyReelsV2AttnProcessor(*args, **kwargs)
class SkyReelsV2Attention(torch.nn.Module, AttentionModuleMixin):
_default_processor_cls = SkyReelsV2AttnProcessor
_available_processors = [SkyReelsV2AttnProcessor]
def __init__(
self,
dim: int,
heads: int = 8,
dim_head: int = 64,
eps: float = 1e-5,
dropout: float = 0.0,
added_kv_proj_dim: Optional[int] = None,
cross_attention_dim_head: Optional[int] = None,
processor=None,
is_cross_attention=None,
):
super().__init__()
self.inner_dim = dim_head * heads
self.heads = heads
self.added_kv_proj_dim = added_kv_proj_dim
self.cross_attention_dim_head = cross_attention_dim_head
self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
self.to_out = torch.nn.ModuleList(
[
torch.nn.Linear(self.inner_dim, dim, bias=True),
torch.nn.Dropout(dropout),
]
)
self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
self.add_k_proj = self.add_v_proj = None
if added_kv_proj_dim is not None:
self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
self.is_cross_attention = cross_attention_dim_head is not None
self.set_processor(processor)
def fuse_projections(self):
if getattr(self, "fused_projections", False):
return
if self.cross_attention_dim_head is None:
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
out_features, in_features = concatenated_weights.shape
with torch.device("meta"):
self.to_qkv = nn.Linear(in_features, out_features, bias=True)
self.to_qkv.load_state_dict(
{"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
)
else:
concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
out_features, in_features = concatenated_weights.shape
with torch.device("meta"):
self.to_kv = nn.Linear(in_features, out_features, bias=True)
self.to_kv.load_state_dict(
{"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
)
if self.added_kv_proj_dim is not None:
concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data])
concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data])
out_features, in_features = concatenated_weights.shape
with torch.device("meta"):
self.to_added_kv = nn.Linear(in_features, out_features, bias=True)
self.to_added_kv.load_state_dict(
{"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
)
self.fused_projections = True
@torch.no_grad()
def unfuse_projections(self):
if not getattr(self, "fused_projections", False):
return
if hasattr(self, "to_qkv"):
delattr(self, "to_qkv")
if hasattr(self, "to_kv"):
delattr(self, "to_kv")
if hasattr(self, "to_added_kv"):
delattr(self, "to_added_kv")
self.fused_projections = False
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
) -> torch.Tensor:
return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs)
class SkyReelsV2ImageEmbedding(torch.nn.Module):
def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
super().__init__()
@@ -213,7 +370,11 @@ class SkyReelsV2TimeTextImageEmbedding(nn.Module):
class SkyReelsV2RotaryPosEmbed(nn.Module):
def __init__(
self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0
self,
attention_head_dim: int,
patch_size: Tuple[int, int, int],
max_seq_len: int,
theta: float = 10000.0,
):
super().__init__()
@@ -223,37 +384,55 @@ class SkyReelsV2RotaryPosEmbed(nn.Module):
h_dim = w_dim = 2 * (attention_head_dim // 6)
t_dim = attention_head_dim - h_dim - w_dim
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
freqs_cos = []
freqs_sin = []
freqs = []
for dim in [t_dim, h_dim, w_dim]:
freq = get_1d_rotary_pos_embed(
dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float32
freq_cos, freq_sin = get_1d_rotary_pos_embed(
dim,
max_seq_len,
theta,
use_real=True,
repeat_interleave_real=True,
freqs_dtype=freqs_dtype,
)
freqs.append(freq)
self.freqs = torch.cat(freqs, dim=1)
freqs_cos.append(freq_cos)
freqs_sin.append(freq_sin)
self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.patch_size
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
freqs = self.freqs.to(hidden_states.device)
freqs = freqs.split_with_sizes(
[
self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
self.attention_head_dim // 6,
self.attention_head_dim // 6,
],
dim=1,
)
split_sizes = [
self.attention_head_dim - 2 * (self.attention_head_dim // 3),
self.attention_head_dim // 3,
self.attention_head_dim // 3,
]
freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
return freqs
freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
return freqs_cos, freqs_sin
@maybe_allow_in_graph
class SkyReelsV2TransformerBlock(nn.Module):
def __init__(
self,
@@ -269,33 +448,24 @@ class SkyReelsV2TransformerBlock(nn.Module):
# 1. Self-attention
self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
self.attn1 = Attention(
query_dim=dim,
self.attn1 = SkyReelsV2Attention(
dim=dim,
heads=num_heads,
kv_heads=num_heads,
dim_head=dim // num_heads,
qk_norm=qk_norm,
eps=eps,
bias=True,
cross_attention_dim=None,
out_bias=True,
processor=SkyReelsV2AttnProcessor2_0(),
cross_attention_dim_head=None,
processor=SkyReelsV2AttnProcessor(),
)
# 2. Cross-attention
self.attn2 = Attention(
query_dim=dim,
self.attn2 = SkyReelsV2Attention(
dim=dim,
heads=num_heads,
kv_heads=num_heads,
dim_head=dim // num_heads,
qk_norm=qk_norm,
eps=eps,
bias=True,
cross_attention_dim=None,
out_bias=True,
added_kv_proj_dim=added_kv_proj_dim,
added_proj_bias=True,
processor=SkyReelsV2AttnProcessor2_0(),
cross_attention_dim_head=dim // num_heads,
processor=SkyReelsV2AttnProcessor(),
)
self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
@@ -321,15 +491,15 @@ class SkyReelsV2TransformerBlock(nn.Module):
# For 4D temb in Diffusion Forcing framework, we assume the shape is (b, 6, f * pp_h * pp_w, inner_dim)
e = (self.scale_shift_table.unsqueeze(2) + temb.float()).chunk(6, dim=1)
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in e]
# 1. Self-attention
norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
attn_output = self.attn1(
hidden_states=norm_hidden_states, rotary_emb=rotary_emb, attention_mask=attention_mask
)
attn_output = self.attn1(norm_hidden_states, None, attention_mask, rotary_emb)
hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
# 2. Cross-attention
norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
hidden_states = hidden_states + attn_output
# 3. Feed-forward
@@ -338,10 +508,13 @@ class SkyReelsV2TransformerBlock(nn.Module):
)
ff_output = self.ffn(norm_hidden_states)
hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
return hidden_states
class SkyReelsV2Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
class SkyReelsV2Transformer3DModel(
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
):
r"""
A Transformer model for video-like data used in the Wan-based SkyReels-V2 model.
@@ -389,6 +562,7 @@ class SkyReelsV2Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fr
_no_split_modules = ["SkyReelsV2TransformerBlock"]
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
_repeated_blocks = ["SkyReelsV2TransformerBlock"]
@register_to_config
def __init__(