mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
Propose to update & upgrade SkyReels-V2 (#12167)
* fix: update SkyReels-V2 documentation and move to the attn dispatcher
* Refactors SkyReelsV2's attention implementation
* style
* up
* Fixes formatting in SkyReels-V2 documentation
Wraps the visual demonstration section in a Markdown code block.
This change corrects the rendering of ASCII diagrams and examples, improving the overall readability of the document.
* Docs: Condense example arrays in skyreels_v2 guide
Improves the readability of the `step_matrix` examples by replacing long sequences of repeated numbers with a more compact `value×count` notation.
This change makes the underlying data patterns in the examples easier to understand at a glance.
* Add _repeated_blocks attribute to SkyReelsV2Transformer3DModel
* Refactor rotary embedding calculations in SkyReelsV2 to separate cosine and sine frequencies
* Enhance SkyReels-V2 documentation: update model loading for GPU support and remove outdated notes
* up
* up
* Update model_id in SkyReels-V2 documentation
* up
* refactor: remove device_map parameter for model loading and add pipeline.to("cuda") for GPU allocation
* fix: update copyright year to 2025 in skyreels_v2.md
* docs: enhance parameter examples and formatting in skyreels_v2.md
* docs: update example formatting and add notes on LoRA support in skyreels_v2.md
* refactor: remove copied comments from transformer_wan in SkyReelsV2 classes
* Clean up comments in skyreels_v2.md
Removed comments about acceleration helpers and Flash Attention installation.
* Add deprecation warning for `SkyReelsV2AttnProcessor2_0` class
skyreels_v2.md
@@ -1,4 +1,4 @@
-<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
+<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -22,7 +22,7 @@
 
 # SkyReels-V2: Infinite-length Film Generative model
 
-[SkyReels-V2](https://huggingface.co/papers/2504.13074) by the SkyReels Team.
+[SkyReels-V2](https://huggingface.co/papers/2504.13074) by the SkyReels Team from Skywork AI.
 
 *Recent advances in video generation have been driven by diffusion models and autoregressive frameworks, yet critical challenges persist in harmonizing prompt adherence, visual quality, motion dynamics, and duration: compromises in motion dynamics to enhance temporal visual quality, constrained video duration (5-10 seconds) to prioritize resolution, and inadequate shot-aware generation stemming from general-purpose MLLMs' inability to interpret cinematic grammar, such as shot composition, actor expressions, and camera motions. These intertwined limitations hinder realistic long-form synthesis and professional film-style generation. To address these limitations, we propose SkyReels-V2, an Infinite-length Film Generative Model, that synergizes Multi-modal Large Language Model (MLLM), Multi-stage Pretraining, Reinforcement Learning, and Diffusion Forcing Framework. Firstly, we design a comprehensive structural representation of video that combines the general descriptions by the Multi-modal LLM and the detailed shot language by sub-expert models. Aided with human annotation, we then train a unified Video Captioner, named SkyCaptioner-V1, to efficiently label the video data. Secondly, we establish progressive-resolution pretraining for the fundamental video generation, followed by a four-stage post-training enhancement: Initial concept-balanced Supervised Fine-Tuning (SFT) improves baseline quality; Motion-specific Reinforcement Learning (RL) training with human-annotated and synthetic distortion data addresses dynamic artifacts; Our diffusion forcing framework with non-decreasing noise schedules enables long-video synthesis in an efficient search space; Final high-quality SFT refines visual fidelity. All the code and models are available at [this https URL](https://github.com/SkyworkAI/SkyReels-V2).*
 
@@ -44,93 +44,113 @@ The following SkyReels-V2 models are supported in Diffusers:
 
 ### A _Visual_ Demonstration
 
-An example with these parameters:
-base_num_frames=97, num_frames=97, num_inference_steps=30, ar_step=5, causal_block_size=5
+The example below has the following parameters:
 
-vae_scale_factor_temporal -> 4
-num_latent_frames: (97-1)//vae_scale_factor_temporal+1 = 25 frames -> 5 blocks of 5 frames each
+- `base_num_frames=97`
+- `num_frames=97`
+- `num_inference_steps=30`
+- `ar_step=5`
+- `causal_block_size=5`
 
-base_num_latent_frames = (97-1)//vae_scale_factor_temporal+1 = 25 → blocks = 25//5 = 5 blocks
-This 5 blocks means the maximum context length of the model is 25 frames in the latent space.
+With `vae_scale_factor_temporal=4`, expect `5` blocks of `5` frames each as calculated by:
 
-Asynchronous Processing Timeline:
-┌──────────────────────────────────────────────────────────────────────────┐
-│ Steps:   1    6    11   16   21   26   31   36   41   46   50            │
-│ Block 1: [■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■]  │
-│ Block 2: [■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■]  │
-│ Block 3: [■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■]  │
-│ Block 4: [■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■]  │
-│ Block 5: [■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■]  │
-└──────────────────────────────────────────────────────────────────────────┘
+`num_latent_frames: (97-1)//vae_scale_factor_temporal+1 = 25 frames -> 5 blocks of 5 frames each`
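To make the arithmetic above concrete, here is a minimal sketch (the `latent_blocks` helper is hypothetical, not a diffusers API) that reproduces the numbers in this example:

```py
# Hypothetical helper reproducing the latent-frame arithmetic above.
def latent_blocks(num_frames=97, vae_scale_factor_temporal=4, causal_block_size=5):
    num_latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1
    num_blocks = num_latent_frames // causal_block_size
    return num_latent_frames, num_blocks

print(latent_blocks())  # (25, 5) -> 5 blocks of 5 latent frames each
```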
 
-For Long Videos (num_frames > base_num_frames):
-base_num_frames acts as the "sliding window size" for processing long videos.
+And the maximum context length in the latent space is calculated with `base_num_latent_frames`:
 
-Example: 257-frame video with base_num_frames=97, overlap_history=17
-┌──── Iteration 1 (frames 1-97) ─────┐
-│ Processing window: 97 frames       │ ← 5 blocks, async processing
-│ Generates: frames 1-97             │
-└────────────────────────────────────┘
-┌────── Iteration 2 (frames 81-177) ───────┐
-│ Processing window: 97 frames             │
-│ Overlap: 17 frames (81-97) from prev     │ ← 5 blocks, async processing
-│ Generates: frames 98-177                 │
-└──────────────────────────────────────────┘
-┌────── Iteration 3 (frames 161-257) ──────┐
-│ Processing window: 97 frames             │
-│ Overlap: 17 frames (161-177) from prev   │ ← 5 blocks, async processing
-│ Generates: frames 178-257                │
-└──────────────────────────────────────────┘
+`base_num_latent_frames = (97-1)//vae_scale_factor_temporal+1 = 25 -> 25//5 = 5 blocks`
 
-Each iteration independently runs the asynchronous processing with its own 5 blocks.
-base_num_frames controls:
-1. Memory usage (larger window = more VRAM)
-2. Model context length (must match training constraints)
-3. Number of blocks per iteration (base_num_latent_frames // causal_block_size)
+Asynchronous Processing Timeline:
+```text
+┌──────────────────────────────────────────────────────────────────────────┐
+│ Steps:   1    6    11   16   21   26   31   36   41   46   50            │
+│ Block 1: [■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■]  │
+│ Block 2: [■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■]  │
+│ Block 3: [■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■]  │
+│ Block 4: [■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■]  │
+│ Block 5: [■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■]  │
+└──────────────────────────────────────────────────────────────────────────┘
+```
 
-Each block takes 30 steps to complete denoising.
-Block N starts at step: 1 + (N-1) x ar_step
-Total steps: 30 + (5-1) x 5 = 50 steps
+For Long Videos (`num_frames` > `base_num_frames`):
+`base_num_frames` acts as the "sliding window size" for processing long videos.
 
+Example: `257`-frame video with `base_num_frames=97`, `overlap_history=17`
+```text
+┌──── Iteration 1 (frames 1-97) ─────┐
+│ Processing window: 97 frames       │ ← 5 blocks,
+│ Generates: frames 1-97             │   async processing
+└────────────────────────────────────┘
+┌────── Iteration 2 (frames 81-177) ───────┐
+│ Processing window: 97 frames             │
+│ Overlap: 17 frames (81-97) from prev     │ ← 5 blocks,
+│ Generates: frames 98-177                 │   async processing
+└──────────────────────────────────────────┘
+┌────── Iteration 3 (frames 161-257) ──────┐
+│ Processing window: 97 frames             │
+│ Overlap: 17 frames (161-177) from prev   │ ← 5 blocks,
+│ Generates: frames 178-257                │   async processing
+└──────────────────────────────────────────┘
+```
 
+Each iteration independently runs the asynchronous processing with its own `5` blocks.
+`base_num_frames` controls:
+1. Memory usage (larger window = more VRAM)
+2. Model context length (must match training constraints)
+3. Number of blocks per iteration (`base_num_latent_frames // causal_block_size`)
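The sliding-window bookkeeping in the diagrams above can be sketched in a few lines. This is a hypothetical illustration (1-indexed, inclusive frame ranges as in the diagrams), not the pipeline's internal logic:

```py
# Hypothetical sketch of the sliding-window frame ranges shown in the diagrams.
def sliding_windows(num_frames=257, base_num_frames=97, overlap_history=17):
    stride = base_num_frames - overlap_history  # 97 - 17 = 80 new frames per iteration
    start = 0
    while start + base_num_frames <= num_frames:
        yield (start + 1, start + base_num_frames)
        start += stride

print(list(sliding_windows()))  # [(1, 97), (81, 177), (161, 257)]
```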
+
+Each block takes `30` steps to complete denoising.
+Block N starts at step: `1 + (N-1) x ar_step`
+Total steps: `30 + (5-1) x 5 = 50` steps
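The step arithmetic above is easy to verify directly (a sketch of the stated formulas, not pipeline code):

```py
num_inference_steps, ar_step, num_blocks = 30, 5, 5

block_start_steps = [1 + (n - 1) * ar_step for n in range(1, num_blocks + 1)]
total_steps = num_inference_steps + (num_blocks - 1) * ar_step

print(block_start_steps)  # [1, 6, 11, 16, 21]
print(total_steps)        # 50
```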
 
 
-Synchronous mode (ar_step=0) would process all blocks/frames simultaneously:
-┌────────────────────────────────────────────────────────────────────────────┐
-│ Steps:        1              ...              30                           │
-│ All blocks: [■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■] │
-└────────────────────────────────────────────────────────────────────────────┘
-Total steps: 30 steps
+Synchronous mode (`ar_step=0`) would process all blocks/frames simultaneously:
+```text
+┌────────────────────────────────────────────────────────────────────────────┐
+│ Steps:        1              ...              30                           │
+│ All blocks: [■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■] │
+└────────────────────────────────────────────────────────────────────────────┘
+```
+Total steps: `30` steps
 
 
-An example on how the step matrix is constructed for asynchronous processing:
-Given the parameters: (num_inference_steps=30, flow_shift=8, num_frames=97, ar_step=5, causal_block_size=5)
-- num_latent_frames = (97 frames - 1) // (4 temporal downsampling) + 1 = 25
-- step_template = [999, 995, 991, 986, 980, 975, 969, 963, 956, 948,
-                   941, 932, 922, 912, 901, 888, 874, 859, 841, 822,
-                   799, 773, 743, 708, 666, 615, 551, 470, 363, 216]
+An example on how the step matrix is constructed for asynchronous processing:
+Given the parameters: (`num_inference_steps=30, flow_shift=8, num_frames=97, ar_step=5, causal_block_size=5`)
+```
+- num_latent_frames = (97 frames - 1) // (4 temporal downsampling) + 1 = 25
+- step_template = [999, 995, 991, 986, 980, 975, 969, 963, 956, 948,
+                   941, 932, 922, 912, 901, 888, 874, 859, 841, 822,
+                   799, 773, 743, 708, 666, 615, 551, 470, 363, 216]
+```
 
-The algorithm creates a 50x25 step_matrix where:
-- Row 1:  [999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999]
-- Row 2:  [995, 995, 995, 995, 995, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999]
-- Row 3:  [991, 991, 991, 991, 991, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999]
-- ...
-- Row 7:  [969, 969, 969, 969, 969, 995, 995, 995, 995, 995, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999]
-- ...
-- Row 21: [799, 799, 799, 799, 799, 888, 888, 888, 888, 888, 941, 941, 941, 941, 941, 975, 975, 975, 975, 975, 999, 999, 999, 999, 999]
-- ...
-- Row 35: [  0,   0,   0,   0,   0, 216, 216, 216, 216, 216, 666, 666, 666, 666, 666, 822, 822, 822, 822, 822, 901, 901, 901, 901, 901]
-- ...
-- Row 42: [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 551, 551, 551, 551, 551, 773, 773, 773, 773, 773]
-- ...
-- Row 50: [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 216, 216, 216, 216, 216]
+The algorithm creates a `50x25` `step_matrix` where:
+```
+- Row 1:  [999×5, 999×5, 999×5, 999×5, 999×5]
+- Row 2:  [995×5, 999×5, 999×5, 999×5, 999×5]
+- Row 3:  [991×5, 999×5, 999×5, 999×5, 999×5]
+- ...
+- Row 7:  [969×5, 995×5, 999×5, 999×5, 999×5]
+- ...
+- Row 21: [799×5, 888×5, 941×5, 975×5, 999×5]
+- ...
+- Row 35: [  0×5, 216×5, 666×5, 822×5, 901×5]
+- ...
+- Row 42: [  0×5,   0×5,   0×5, 551×5, 773×5]
+- ...
+- Row 50: [  0×5,   0×5,   0×5,   0×5, 216×5]
+```
 
-Detailed Row 6 Analysis:
-- step_matrix[5]:      [ 975, 975, 975, 975, 975, 999, 999, 999, 999, 999, 999, ..., 999]
-- step_index[5]:       [   6,   6,   6,   6,   6,   1,   1,   1,   1,   1,   0, ...,   0]
-- step_update_mask[5]: [True,True,True,True,True,True,True,True,True,True,False, ...,False]
-- valid_interval[5]:   (0, 25)
+Detailed Row `6` Analysis:
+```
+- step_matrix[5]:      [ 975×5, 999×5, 999×5, 999×5, 999×5]
+- step_index[5]:       [   6×5,   1×5,   0×5,   0×5,   0×5]
+- step_update_mask[5]: [True×5, True×5, False×5, False×5, False×5]
+- valid_interval[5]:   (0, 25)
+```
 
-Key Pattern: Block i lags behind Block i-1 by exactly ar_step=5 timesteps, creating the
+Key Pattern: Block `i` lags behind Block `i-1` by exactly `ar_step=5` timesteps, creating the
 staggered "diffusion forcing" effect where later blocks condition on cleaner earlier blocks.
 
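For readers who want to reproduce the `step_matrix` rows quoted above, the following is a minimal reconstruction consistent with them (the pipeline builds the matrix differently internally; this only illustrates the staggering rule):

```py
step_template = [999, 995, 991, 986, 980, 975, 969, 963, 956, 948,
                 941, 932, 922, 912, 901, 888, 874, 859, 841, 822,
                 799, 773, 743, 708, 666, 615, 551, 470, 363, 216]
num_latent_frames, causal_block_size, ar_step = 25, 5, 5
num_blocks = num_latent_frames // causal_block_size          # 5
num_rows = len(step_template) + (num_blocks - 1) * ar_step   # 50

step_matrix = []
for row in range(num_rows):
    timesteps = []
    for block in range(num_blocks):
        local = row - block * ar_step          # this block's position in step_template
        if local < 0:
            t = 999                            # block has not started denoising yet
        elif local < len(step_template):
            t = step_template[local]
        else:
            t = 0                              # block is fully denoised
        timesteps += [t] * causal_block_size
    step_matrix.append(timesteps)

print(step_matrix[5][:10])   # [975, 975, 975, 975, 975, 999, 999, 999, 999, 999]
print(step_matrix[49][-5:])  # [216, 216, 216, 216, 216]
```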
 ### Text-to-Video Generation
 
@@ -145,23 +165,22 @@ From the original repo:
 
 >You can use --ar_step 5 to enable asynchronous inference. When asynchronous inference, --causal_block_size 5 is recommended while it is not supposed to be set for synchronous generation... Asynchronous inference will take more steps to diffuse the whole sequence which means it will be SLOWER than synchronous mode. In our experiments, asynchronous inference may improve the instruction following and visual consistent performance.
 
 ```py
 # pip install ftfy
 import torch
 from diffusers import AutoModel, SkyReelsV2DiffusionForcingPipeline, UniPCMultistepScheduler
 from diffusers.utils import export_to_video
 
-vae = AutoModel.from_pretrained("Skywork/SkyReels-V2-DF-14B-540P-Diffusers", subfolder="vae", torch_dtype=torch.float32)
-transformer = AutoModel.from_pretrained("Skywork/SkyReels-V2-DF-14B-540P-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
-
+model_id = "Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers"
+vae = AutoModel.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
 
 pipeline = SkyReelsV2DiffusionForcingPipeline.from_pretrained(
-    "Skywork/SkyReels-V2-DF-14B-540P-Diffusers",
+    model_id,
     vae=vae,
-    transformer=transformer,
-    torch_dtype=torch.bfloat16
+    torch_dtype=torch.bfloat16,
 )
+pipeline.to("cuda")
 flow_shift = 8.0 # 8.0 for T2V, 5.0 for I2V
 pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config, flow_shift=flow_shift)
-pipeline = pipeline.to("cuda")
 
 prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
 
@@ -177,7 +196,7 @@ output = pipeline(
     overlap_history=None, # Number of frames to overlap for smooth transitions in long videos; 17 for long video generations
     addnoise_condition=20, # Improves consistency in long video generation
 ).frames[0]
-export_to_video(output, "T2V.mp4", fps=24, quality=8)
+export_to_video(output, "video.mp4", fps=24, quality=8)
 ```
 
 </hfoption>
@@ -198,14 +217,14 @@ from diffusers import AutoencoderKLWan, SkyReelsV2DiffusionForcingImageToVideoPi
 from diffusers.utils import export_to_video, load_image
 
 
-model_id = "Skywork/SkyReels-V2-DF-14B-720P-Diffusers"
+model_id = "Skywork/SkyReels-V2-DF-1.3B-720P-Diffusers"
 vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
 pipeline = SkyReelsV2DiffusionForcingImageToVideoPipeline.from_pretrained(
     model_id, vae=vae, torch_dtype=torch.bfloat16
 )
+pipeline.to("cuda")
 flow_shift = 5.0 # 8.0 for T2V, 5.0 for I2V
 pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config, flow_shift=flow_shift)
-pipeline.to("cuda")
 
 first_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png")
 last_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png")
@@ -239,7 +258,7 @@ prompt = "CG animation style, a small blue bird takes off from the ground, flapp
 output = pipeline(
     image=first_frame, last_image=last_frame, prompt=prompt, height=height, width=width, guidance_scale=5.0
 ).frames[0]
-export_to_video(output, "output.mp4", fps=24, quality=8)
+export_to_video(output, "video.mp4", fps=24, quality=8)
 ```
 
 </hfoption>
@@ -261,75 +280,35 @@ from diffusers import AutoencoderKLWan, SkyReelsV2DiffusionForcingVideoToVideoPi
 from diffusers.utils import export_to_video, load_video
 
 
-model_id = "Skywork/SkyReels-V2-DF-14B-540P-Diffusers"
+model_id = "Skywork/SkyReels-V2-DF-1.3B-720P-Diffusers"
 vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
 pipeline = SkyReelsV2DiffusionForcingVideoToVideoPipeline.from_pretrained(
     model_id, vae=vae, torch_dtype=torch.bfloat16
 )
+pipeline.to("cuda")
 flow_shift = 5.0 # 8.0 for T2V, 5.0 for I2V
 pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config, flow_shift=flow_shift)
-pipeline.to("cuda")
 
 video = load_video("input_video.mp4")
 
 prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
 
 output = pipeline(
-    video=video, prompt=prompt, height=544, width=960, guidance_scale=5.0,
-    num_inference_steps=30, num_frames=257, base_num_frames=97#, ar_step=5, causal_block_size=5,
+    video=video, prompt=prompt, height=720, width=1280, guidance_scale=5.0, overlap_history=17,
+    num_inference_steps=30, num_frames=257, base_num_frames=121#, ar_step=5, causal_block_size=5,
 ).frames[0]
-export_to_video(output, "output.mp4", fps=24, quality=8)
-# Total frames will be the number of frames of given video + 257
+export_to_video(output, "video.mp4", fps=24, quality=8)
+# Total frames will be the number of frames of the given video + 257
 ```
 
 </hfoption>
 </hfoptions>
 
 
 ## Notes
 
 - SkyReels-V2 supports LoRAs with [`~loaders.SkyReelsV2LoraLoaderMixin.load_lora_weights`].
 
 <details>
 <summary>Show example code</summary>
 
 ```py
 # pip install ftfy
 import torch
 from diffusers import AutoModel, SkyReelsV2DiffusionForcingPipeline
 from diffusers.utils import export_to_video
 
 vae = AutoModel.from_pretrained(
     "Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers", subfolder="vae", torch_dtype=torch.float32
 )
 pipeline = SkyReelsV2DiffusionForcingPipeline.from_pretrained(
     "Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers", vae=vae, torch_dtype=torch.bfloat16
 )
 pipeline.to("cuda")
 
 pipeline.load_lora_weights("benjamin-paine/steamboat-willie-1.3b", adapter_name="steamboat-willie")
 pipeline.set_adapters("steamboat-willie")
 
 pipeline.enable_model_cpu_offload()
 
 # use "steamboat willie style" to trigger the LoRA
 prompt = """
 steamboat willie style, golden era animation, The camera rushes from far to near in a low-angle shot,
 revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
 for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
 Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
 shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
 """
 
 output = pipeline(
     prompt=prompt,
     num_frames=97,
     guidance_scale=6.0,
 ).frames[0]
 export_to_video(output, "output.mp4", fps=24)
 ```
 
 </details>
 
 `SkyReelsV2Pipeline` and `SkyReelsV2ImageToVideoPipeline` are also available without Diffusion Forcing framework applied.
 
 
 ## SkyReelsV2DiffusionForcingPipeline
 
@@ -364,4 +343,4 @@ export_to_video(output, "output.mp4", fps=24, quality=8)
 
 ## SkyReelsV2PipelineOutput
 
-[[autodoc]] pipelines.skyreels_v2.pipeline_output.SkyReelsV2PipelineOutput
+[[autodoc]] pipelines.skyreels_v2.pipeline_output.SkyReelsV2PipelineOutput
transformer_skyreels_v2.py
@@ -1,4 +1,4 @@
-# Copyright 2025 The SkyReels-V2 Team, The Wan Team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 The SkyReels Team, The Wan Team and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -21,9 +21,10 @@ import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
-from ..attention import FeedForward
-from ..attention_processor import Attention
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
 from ..cache_utils import CacheMixin
 from ..embeddings import (
     PixArtAlphaTextProjection,
@@ -39,20 +40,53 @@ from ..normalization import FP32LayerNorm
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-class SkyReelsV2AttnProcessor2_0:
+def _get_qkv_projections(
+    attn: "SkyReelsV2Attention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor
+):
+    # encoder_hidden_states is only passed for cross-attention
+    if encoder_hidden_states is None:
+        encoder_hidden_states = hidden_states
+
+    if attn.fused_projections:
+        if attn.cross_attention_dim_head is None:
+            # In self-attention layers, we can fuse the entire QKV projection into a single linear
+            query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+        else:
+            # In cross-attention layers, we can only fuse the KV projections into a single linear
+            query = attn.to_q(hidden_states)
+            key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1)
+    else:
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+    return query, key, value
+
+
+def _get_added_kv_projections(attn: "SkyReelsV2Attention", encoder_hidden_states_img: torch.Tensor):
+    if attn.fused_projections:
+        key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1)
+    else:
+        key_img = attn.add_k_proj(encoder_hidden_states_img)
+        value_img = attn.add_v_proj(encoder_hidden_states_img)
+    return key_img, value_img
+
+
+class SkyReelsV2AttnProcessor:
+    _attention_backend = None
+
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError(
-                "SkyReelsV2AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
+                "SkyReelsV2AttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
             )
 
     def __call__(
         self,
-        attn: Attention,
+        attn: "SkyReelsV2Attention",
         hidden_states: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        rotary_emb: Optional[torch.Tensor] = None,
+        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
         encoder_hidden_states_img = None
         if attn.add_k_proj is not None:
@@ -60,58 +94,66 @@ class SkyReelsV2AttnProcessor2_0:
             image_context_length = encoder_hidden_states.shape[1] - 512
             encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
             encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
 
-        query = attn.to_q(hidden_states)
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
+        query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
 
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
+        query = attn.norm_q(query)
+        key = attn.norm_k(key)
 
-        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        query = query.unflatten(2, (attn.heads, -1))
+        key = key.unflatten(2, (attn.heads, -1))
+        value = value.unflatten(2, (attn.heads, -1))
 
         if rotary_emb is not None:
 
-            def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
-                x_rotated = torch.view_as_complex(hidden_states.to(torch.float32).unflatten(3, (-1, 2)))
-                x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
-                return x_out.type_as(hidden_states)
+            def apply_rotary_emb(
+                hidden_states: torch.Tensor,
+                freqs_cos: torch.Tensor,
+                freqs_sin: torch.Tensor,
+            ):
+                x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
+                cos = freqs_cos[..., 0::2]
+                sin = freqs_sin[..., 1::2]
+                out = torch.empty_like(hidden_states)
+                out[..., 0::2] = x1 * cos - x2 * sin
+                out[..., 1::2] = x1 * sin + x2 * cos
+                return out.type_as(hidden_states)
 
-            query = apply_rotary_emb(query, rotary_emb)
-            key = apply_rotary_emb(key, rotary_emb)
+            query = apply_rotary_emb(query, *rotary_emb)
+            key = apply_rotary_emb(key, *rotary_emb)
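The rewritten `apply_rotary_emb` replaces the complex-number rotation with an explicit interleaved cos/sin form. A quick standalone check that the two formulations agree, assuming interleave-repeated frequencies as produced by `get_1d_rotary_pos_embed(..., use_real=True, repeat_interleave_real=True)`:

```py
import torch

theta = torch.rand(4, dtype=torch.float32)    # one rotation angle per channel pair
freqs_cos = theta.cos().repeat_interleave(2)  # interleave-repeated cosines
freqs_sin = theta.sin().repeat_interleave(2)  # interleave-repeated sines
x = torch.rand(8, dtype=torch.float32)

# New real-valued path, as in the processor above
x1, x2 = x.unflatten(-1, (-1, 2)).unbind(-1)
cos, sin = freqs_cos[0::2], freqs_sin[1::2]
out = torch.empty_like(x)
out[0::2] = x1 * cos - x2 * sin
out[1::2] = x1 * sin + x2 * cos

# Old complex-valued path
x_complex = torch.view_as_complex(x.unflatten(-1, (-1, 2)))
ref = torch.view_as_real(x_complex * torch.polar(torch.ones_like(theta), theta)).flatten(-2)

assert torch.allclose(out, ref, atol=1e-6)
```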
 
         # I2V task
         hidden_states_img = None
         if encoder_hidden_states_img is not None:
-            key_img = attn.add_k_proj(encoder_hidden_states_img)
+            key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img)
             key_img = attn.norm_added_k(key_img)
-            value_img = attn.add_v_proj(encoder_hidden_states_img)
 
-            key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-            value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+            key_img = key_img.unflatten(2, (attn.heads, -1))
+            value_img = value_img.unflatten(2, (attn.heads, -1))
 
-            hidden_states_img = F.scaled_dot_product_attention(
-                query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
+            hidden_states_img = dispatch_attention_fn(
+                query,
+                key_img,
+                value_img,
+                attn_mask=None,
+                dropout_p=0.0,
+                is_causal=False,
+                backend=self._attention_backend,
             )
-            hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
+            hidden_states_img = hidden_states_img.flatten(2, 3)
             hidden_states_img = hidden_states_img.type_as(query)
 
-        hidden_states = F.scaled_dot_product_attention(
+        hidden_states = dispatch_attention_fn(
             query,
             key,
             value,
             attn_mask=attention_mask,
             dropout_p=0.0,
             is_causal=False,
+            backend=self._attention_backend,
         )
 
-        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+        hidden_states = hidden_states.flatten(2, 3)
         hidden_states = hidden_states.type_as(query)
 
         if hidden_states_img is not None:
@@ -122,7 +164,122 @@ class SkyReelsV2AttnProcessor2_0:
         return hidden_states
 
 
-# Copied from diffusers.models.transformers.transformer_wan.WanImageEmbedding with WanImageEmbedding -> SkyReelsV2ImageEmbedding
+class SkyReelsV2AttnProcessor2_0:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = (
+            "The SkyReelsV2AttnProcessor2_0 class is deprecated and will be removed in a future version. "
+            "Please use SkyReelsV2AttnProcessor instead. "
+        )
+        deprecate("SkyReelsV2AttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False)
+        return SkyReelsV2AttnProcessor(*args, **kwargs)
+
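Because `__new__` returns an instance of the replacement class, code that still instantiates the old name keeps working (illustrative only):

```py
proc = SkyReelsV2AttnProcessor2_0()  # emits a deprecation warning
assert isinstance(proc, SkyReelsV2AttnProcessor)
```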
+
+class SkyReelsV2Attention(torch.nn.Module, AttentionModuleMixin):
+    _default_processor_cls = SkyReelsV2AttnProcessor
+    _available_processors = [SkyReelsV2AttnProcessor]
+
+    def __init__(
+        self,
+        dim: int,
+        heads: int = 8,
+        dim_head: int = 64,
+        eps: float = 1e-5,
+        dropout: float = 0.0,
+        added_kv_proj_dim: Optional[int] = None,
+        cross_attention_dim_head: Optional[int] = None,
+        processor=None,
+        is_cross_attention=None,
+    ):
+        super().__init__()
+
+        self.inner_dim = dim_head * heads
+        self.heads = heads
+        self.added_kv_proj_dim = added_kv_proj_dim
+        self.cross_attention_dim_head = cross_attention_dim_head
+        self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
+
+        self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
+        self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+        self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+        self.to_out = torch.nn.ModuleList(
+            [
+                torch.nn.Linear(self.inner_dim, dim, bias=True),
+                torch.nn.Dropout(dropout),
+            ]
+        )
+        self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+        self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+
+        self.add_k_proj = self.add_v_proj = None
+        if added_kv_proj_dim is not None:
+            self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+            self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+            self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
+
+        self.is_cross_attention = cross_attention_dim_head is not None
+
+        self.set_processor(processor)
+
+    def fuse_projections(self):
+        if getattr(self, "fused_projections", False):
+            return
+
+        if self.cross_attention_dim_head is None:
+            concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+            concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
+            out_features, in_features = concatenated_weights.shape
+            with torch.device("meta"):
+                self.to_qkv = nn.Linear(in_features, out_features, bias=True)
+            self.to_qkv.load_state_dict(
+                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+            )
+        else:
+            concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
+            concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
+            out_features, in_features = concatenated_weights.shape
+            with torch.device("meta"):
+                self.to_kv = nn.Linear(in_features, out_features, bias=True)
+            self.to_kv.load_state_dict(
+                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+            )
+
+        if self.added_kv_proj_dim is not None:
+            concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data])
+            concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data])
+            out_features, in_features = concatenated_weights.shape
+            with torch.device("meta"):
+                self.to_added_kv = nn.Linear(in_features, out_features, bias=True)
+            self.to_added_kv.load_state_dict(
+                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+            )
+
+        self.fused_projections = True
+
+    @torch.no_grad()
+    def unfuse_projections(self):
+        if not getattr(self, "fused_projections", False):
+            return
+
+        if hasattr(self, "to_qkv"):
+            delattr(self, "to_qkv")
+        if hasattr(self, "to_kv"):
+            delattr(self, "to_kv")
+        if hasattr(self, "to_added_kv"):
+            delattr(self, "to_added_kv")
+
+        self.fused_projections = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs)
+
 
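A minimal sketch of what the fused projections are for: fusing only concatenates the Q/K/V weights into a single matmul, so outputs should be unchanged (illustrative, assuming the defaults above and the module defined in this diff):

```py
import torch

attn = SkyReelsV2Attention(dim=64, heads=4, dim_head=16, processor=SkyReelsV2AttnProcessor())
x = torch.randn(1, 10, 64)

with torch.no_grad():
    reference = attn(x)
    attn.fuse_projections()    # packs to_q/to_k/to_v into one to_qkv linear
    fused = attn(x)
    attn.unfuse_projections()  # drops to_qkv again; the original linears were kept

assert torch.allclose(reference, fused, atol=1e-5)
```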
 class SkyReelsV2ImageEmbedding(torch.nn.Module):
     def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
         super().__init__()
@@ -213,7 +370,11 @@ class SkyReelsV2TimeTextImageEmbedding(nn.Module):
 
 class SkyReelsV2RotaryPosEmbed(nn.Module):
     def __init__(
-        self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0
+        self,
+        attention_head_dim: int,
+        patch_size: Tuple[int, int, int],
+        max_seq_len: int,
+        theta: float = 10000.0,
     ):
         super().__init__()
 
@@ -223,37 +384,55 @@ class SkyReelsV2RotaryPosEmbed(nn.Module):
 
         h_dim = w_dim = 2 * (attention_head_dim // 6)
         t_dim = attention_head_dim - h_dim - w_dim
+        freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
 
+        freqs_cos = []
+        freqs_sin = []
+
-        freqs = []
         for dim in [t_dim, h_dim, w_dim]:
-            freq = get_1d_rotary_pos_embed(
-                dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float32
+            freq_cos, freq_sin = get_1d_rotary_pos_embed(
+                dim,
+                max_seq_len,
+                theta,
+                use_real=True,
+                repeat_interleave_real=True,
+                freqs_dtype=freqs_dtype,
             )
-            freqs.append(freq)
-        self.freqs = torch.cat(freqs, dim=1)
+            freqs_cos.append(freq_cos)
+            freqs_sin.append(freq_sin)
+
+        self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
+        self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, num_channels, num_frames, height, width = hidden_states.shape
         p_t, p_h, p_w = self.patch_size
         ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
 
-        freqs = self.freqs.to(hidden_states.device)
-        freqs = freqs.split_with_sizes(
-            [
-                self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
-                self.attention_head_dim // 6,
-                self.attention_head_dim // 6,
-            ],
-            dim=1,
-        )
+        split_sizes = [
+            self.attention_head_dim - 2 * (self.attention_head_dim // 3),
+            self.attention_head_dim // 3,
+            self.attention_head_dim // 3,
+        ]
 
-        freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
-        freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
-        freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
-        freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
-        return freqs
+        freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
+        freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
+
+        freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+        freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+        freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+        freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+
+        return freqs_cos, freqs_sin
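The new `split_sizes` express the same temporal/height/width partition as before, but over the full head dimension instead of its complex-valued half. A quick sanity check (assuming `attention_head_dim=128`, a typical Wan-style value):

```py
d = 128                                            # attention_head_dim (assumption)
h_dim = w_dim = 2 * (d // 6)                       # 42 each, as in __init__ above
t_dim = d - h_dim - w_dim                          # 44
split_sizes = [d - 2 * (d // 3), d // 3, d // 3]   # [44, 42, 42]
assert [t_dim, h_dim, w_dim] == split_sizes
```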
 
 
+@maybe_allow_in_graph
 class SkyReelsV2TransformerBlock(nn.Module):
     def __init__(
         self,
@@ -269,33 +448,24 @@
 
         # 1. Self-attention
         self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
-        self.attn1 = Attention(
-            query_dim=dim,
+        self.attn1 = SkyReelsV2Attention(
+            dim=dim,
             heads=num_heads,
-            kv_heads=num_heads,
             dim_head=dim // num_heads,
-            qk_norm=qk_norm,
             eps=eps,
-            bias=True,
-            cross_attention_dim=None,
-            out_bias=True,
-            processor=SkyReelsV2AttnProcessor2_0(),
+            cross_attention_dim_head=None,
+            processor=SkyReelsV2AttnProcessor(),
         )
 
         # 2. Cross-attention
-        self.attn2 = Attention(
-            query_dim=dim,
+        self.attn2 = SkyReelsV2Attention(
+            dim=dim,
             heads=num_heads,
-            kv_heads=num_heads,
             dim_head=dim // num_heads,
-            qk_norm=qk_norm,
             eps=eps,
-            bias=True,
-            cross_attention_dim=None,
-            out_bias=True,
             added_kv_proj_dim=added_kv_proj_dim,
-            added_proj_bias=True,
-            processor=SkyReelsV2AttnProcessor2_0(),
+            cross_attention_dim_head=dim // num_heads,
+            processor=SkyReelsV2AttnProcessor(),
         )
         self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
 
@@ -321,15 +491,15 @@
         # For 4D temb in Diffusion Forcing framework, we assume the shape is (b, 6, f * pp_h * pp_w, inner_dim)
         e = (self.scale_shift_table.unsqueeze(2) + temb.float()).chunk(6, dim=1)
         shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in e]
 
         # 1. Self-attention
         norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
-        attn_output = self.attn1(
-            hidden_states=norm_hidden_states, rotary_emb=rotary_emb, attention_mask=attention_mask
-        )
+        attn_output = self.attn1(norm_hidden_states, None, attention_mask, rotary_emb)
         hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
 
         # 2. Cross-attention
         norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
-        attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+        attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
         hidden_states = hidden_states + attn_output
 
         # 3. Feed-forward
@@ -338,10 +508,13 @@
         )
         ff_output = self.ffn(norm_hidden_states)
         hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
 
         return hidden_states
 
 
-class SkyReelsV2Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+class SkyReelsV2Transformer3DModel(
+    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
+):
     r"""
     A Transformer model for video-like data used in the Wan-based SkyReels-V2 model.
 
@@ -389,6 +562,7 @@ class SkyReelsV2Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fr
     _no_split_modules = ["SkyReelsV2TransformerBlock"]
     _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
     _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
+    _repeated_blocks = ["SkyReelsV2TransformerBlock"]
 
     @register_to_config
     def __init__(