mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
Merge branch 'main' into flux-control-lora-bnb-8bit
This commit is contained in:
@@ -22,17 +22,30 @@
|
||||
|
||||
# Wan2.1
|
||||
|
||||
[Wan2.1](https://files.alicdn.com/tpsservice/5c9de1c74de03972b7aa657e5a54756b.pdf) is a series of large diffusion transformer available in two versions, a high-performance 14B parameter model and a more accessible 1.3B version. Trained on billions of images and videos, it supports tasks like text-to-video (T2V) and image-to-video (I2V) while enabling features such as camera control and stylistic diversity. The Wan-VAE features better image data compression and a feature cache mechanism that encodes and decodes a video in chunks. To maintain continuity, features from previous chunks are cached and reused for processing subsequent chunks. This improves inference efficiency by reducing memory usage. Wan2.1 also uses a multilingual text encoder and the diffusion transformer models space and time relationships and text conditions with each time step to capture more complex video dynamics.
|
||||
[Wan-2.1](https://huggingface.co/papers/2503.20314) by the Wan Team.
|
||||
|
||||
*This report presents Wan, a comprehensive and open suite of video foundation models designed to push the boundaries of video generation. Built upon the mainstream diffusion transformer paradigm, Wan achieves significant advancements in generative capabilities through a series of innovations, including our novel VAE, scalable pre-training strategies, large-scale data curation, and automated evaluation metrics. These contributions collectively enhance the model's performance and versatility. Specifically, Wan is characterized by four key features: Leading Performance: The 14B model of Wan, trained on a vast dataset comprising billions of images and videos, demonstrates the scaling laws of video generation with respect to both data and model size. It consistently outperforms the existing open-source models as well as state-of-the-art commercial solutions across multiple internal and external benchmarks, demonstrating a clear and significant performance superiority. Comprehensiveness: Wan offers two capable models, i.e., 1.3B and 14B parameters, for efficiency and effectiveness respectively. It also covers multiple downstream applications, including image-to-video, instruction-guided video editing, and personal video generation, encompassing up to eight tasks. Consumer-Grade Efficiency: The 1.3B model demonstrates exceptional resource efficiency, requiring only 8.19 GB VRAM, making it compatible with a wide range of consumer-grade GPUs. Openness: We open-source the entire series of Wan, including source code and all models, with the goal of fostering the growth of the video generation community. This openness seeks to significantly expand the creative possibilities of video production in the industry and provide academia with high-quality video foundation models. All the code and models are available at [this https URL](https://github.com/Wan-Video/Wan2.1).*
|
||||
|
||||
You can find all the original Wan2.1 checkpoints under the [Wan-AI](https://huggingface.co/Wan-AI) organization.
|
||||
|
||||
The following Wan models are supported in Diffusers:
|
||||
- [Wan 2.1 T2V 1.3B](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B-Diffusers)
|
||||
- [Wan 2.1 T2V 14B](https://huggingface.co/Wan-AI/Wan2.1-T2V-14B-Diffusers)
|
||||
- [Wan 2.1 I2V 14B - 480P](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P-Diffusers)
|
||||
- [Wan 2.1 I2V 14B - 720P](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-720P-Diffusers)
|
||||
- [Wan 2.1 FLF2V 14B - 720P](https://huggingface.co/Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers)
|
||||
- [Wan 2.1 VACE 1.3B](https://huggingface.co/Wan-AI/Wan2.1-VACE-1.3B-diffusers)
|
||||
- [Wan 2.1 VACE 14B](https://huggingface.co/Wan-AI/Wan2.1-VACE-14B-diffusers)
|
||||
|
||||
> [!TIP]
|
||||
> Click on the Wan2.1 models in the right sidebar for more examples of video generation.
|
||||
|
||||
### Text-to-Video Generation
|
||||
|
||||
The example below demonstrates how to generate a video from text optimized for memory or inference speed.
|
||||
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="memory">
|
||||
<hfoptions id="T2V usage">
|
||||
<hfoption id="T2V memory">
|
||||
|
||||
Refer to the [Reduce memory usage](../../optimization/memory) guide for more details about the various memory saving techniques.
|
||||
|
||||
@@ -100,7 +113,7 @@ export_to_video(output, "output.mp4", fps=16)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="inference speed">
|
||||
<hfoption id="T2V inference speed">
|
||||
|
||||
[Compilation](../../optimization/fp16#torchcompile) is slow the first time but subsequent calls to the pipeline are faster.
|
||||
|
||||
@@ -157,6 +170,81 @@ export_to_video(output, "output.mp4", fps=16)
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### First-Last-Frame-to-Video Generation
|
||||
|
||||
The example below demonstrates how to use the image-to-video pipeline to generate a video using a text description, a starting frame, and an ending frame.
|
||||
|
||||
<hfoptions id="FLF2V usage">
|
||||
<hfoption id="usage">
|
||||
|
||||
```python
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms.functional as TF
|
||||
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
|
||||
from diffusers.utils import export_to_video, load_image
|
||||
from transformers import CLIPVisionModel
|
||||
|
||||
|
||||
model_id = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
|
||||
image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
|
||||
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
pipe = WanImageToVideoPipeline.from_pretrained(
|
||||
model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe.to("cuda")
|
||||
|
||||
first_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png")
|
||||
last_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png")
|
||||
|
||||
def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
|
||||
aspect_ratio = image.height / image.width
|
||||
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
|
||||
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
|
||||
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
|
||||
image = image.resize((width, height))
|
||||
return image, height, width
|
||||
|
||||
def center_crop_resize(image, height, width):
|
||||
# Calculate resize ratio to match first frame dimensions
|
||||
resize_ratio = max(width / image.width, height / image.height)
|
||||
|
||||
# Resize the image
|
||||
width = round(image.width * resize_ratio)
|
||||
height = round(image.height * resize_ratio)
|
||||
size = [width, height]
|
||||
image = TF.center_crop(image, size)
|
||||
|
||||
return image, height, width
|
||||
|
||||
first_frame, height, width = aspect_ratio_resize(first_frame, pipe)
|
||||
if last_frame.size != first_frame.size:
|
||||
last_frame, _, _ = center_crop_resize(last_frame, height, width)
|
||||
|
||||
prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
|
||||
|
||||
output = pipe(
|
||||
image=first_frame, last_image=last_frame, prompt=prompt, height=height, width=width, guidance_scale=5.5
|
||||
).frames[0]
|
||||
export_to_video(output, "output.mp4", fps=16)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Any-to-Video Controllable Generation
|
||||
|
||||
Wan VACE supports various generation techniques which achieve controllable video generation. Some of the capabilities include:
|
||||
- Control to Video (Depth, Pose, Sketch, Flow, Grayscale, Scribble, Layout, Boundary Box, etc.). Recommended library for preprocessing videos to obtain control videos: [huggingface/controlnet_aux]()
|
||||
- Image/Video to Video (first frame, last frame, starting clip, ending clip, random clips)
|
||||
- Inpainting and Outpainting
|
||||
- Subject to Video (faces, object, characters, etc.)
|
||||
- Composition to Video (reference anything, animate anything, swap anything, expand anything, move anything, etc.)
|
||||
|
||||
The code snippets available in [this](https://github.com/huggingface/diffusers/pull/11582) pull request demonstrate some examples of how videos can be generated with controllability signals.
|
||||
|
||||
The general rule of thumb to keep in mind when preparing inputs for the VACE pipeline is that the input images, or frames of a video that you want to use for conditioning, should have a corresponding mask that is black in color. The black mask signifies that the model will not generate new content for that area, and only use those parts for conditioning the generation process. For parts/frames that should be generated by the model, the mask should be white in color.
|
||||
|
||||
## Notes
|
||||
|
||||
- Wan2.1 supports LoRAs with [`~loaders.WanLoraLoaderMixin.load_lora_weights`].
|
||||
@@ -251,6 +339,18 @@ export_to_video(output, "output.mp4", fps=16)
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## WanVACEPipeline
|
||||
|
||||
[[autodoc]] WanVACEPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## WanVideoToVideoPipeline
|
||||
|
||||
[[autodoc]] WanVideoToVideoPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## WanPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.wan.pipeline_output.WanPipelineOutput
|
||||
205
examples/community/pipeline_stable_diffusion_xl_t5.py
Normal file
205
examples/community/pipeline_stable_diffusion_xl_t5.py
Normal file
@@ -0,0 +1,205 @@
|
||||
# Copyright Philip Brown, ppbrown@github
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
###########################################################################
|
||||
# This pipeline attempts to use a model that has SDXL vae, T5 text encoder,
|
||||
# and SDXL unet.
|
||||
# At the present time, there are no pretrained models that give pleasing
|
||||
# output. So as yet, (2025/06/10) this pipeline is somewhat of a tech
|
||||
# demo proving that the pieces can at least be put together.
|
||||
# Hopefully, it will encourage someone with the hardware available to
|
||||
# throw enough resources into training one up.
|
||||
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers import (
|
||||
CLIPImageProcessor,
|
||||
CLIPTokenizer,
|
||||
CLIPVisionModelWithProjection,
|
||||
T5EncoderModel,
|
||||
)
|
||||
|
||||
from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
|
||||
|
||||
# Note: At this time, the intent is to use the T5 encoder mentioned
|
||||
# below, with zero changes.
|
||||
# Therefore, the model deliberately does not store the T5 encoder model bytes,
|
||||
# (Since they are not unique!)
|
||||
# but instead takes advantage of huggingface hub cache loading
|
||||
|
||||
T5_NAME = "mcmonkey/google_t5-v1_1-xxl_encoderonly"
|
||||
|
||||
# Caller is expected to load this, or equivalent, as model name for now
|
||||
# eg: pipe = StableDiffusionXL_T5Pipeline(SDXL_NAME)
|
||||
SDXL_NAME = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
|
||||
|
||||
class LinearWithDtype(nn.Linear):
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.weight.dtype
|
||||
|
||||
|
||||
class StableDiffusionXL_T5Pipeline(StableDiffusionXLPipeline):
|
||||
_expected_modules = [
|
||||
"vae",
|
||||
"unet",
|
||||
"scheduler",
|
||||
"tokenizer",
|
||||
"image_encoder",
|
||||
"feature_extractor",
|
||||
"t5_encoder",
|
||||
"t5_projection",
|
||||
"t5_pooled_projection",
|
||||
]
|
||||
|
||||
_optional_components = [
|
||||
"image_encoder",
|
||||
"feature_extractor",
|
||||
"t5_encoder",
|
||||
"t5_projection",
|
||||
"t5_pooled_projection",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
tokenizer: CLIPTokenizer,
|
||||
t5_encoder=None,
|
||||
t5_projection=None,
|
||||
t5_pooled_projection=None,
|
||||
image_encoder: CLIPVisionModelWithProjection = None,
|
||||
feature_extractor: CLIPImageProcessor = None,
|
||||
force_zeros_for_empty_prompt: bool = True,
|
||||
add_watermarker: Optional[bool] = None,
|
||||
):
|
||||
DiffusionPipeline.__init__(self)
|
||||
|
||||
if t5_encoder is None:
|
||||
self.t5_encoder = T5EncoderModel.from_pretrained(T5_NAME, torch_dtype=unet.dtype)
|
||||
else:
|
||||
self.t5_encoder = t5_encoder
|
||||
|
||||
# ----- build T5 4096 => 2048 dim projection -----
|
||||
if t5_projection is None:
|
||||
self.t5_projection = LinearWithDtype(4096, 2048) # trainable
|
||||
else:
|
||||
self.t5_projection = t5_projection
|
||||
self.t5_projection.to(dtype=unet.dtype)
|
||||
# ----- build T5 4096 => 1280 dim projection -----
|
||||
if t5_pooled_projection is None:
|
||||
self.t5_pooled_projection = LinearWithDtype(4096, 1280) # trainable
|
||||
else:
|
||||
self.t5_pooled_projection = t5_pooled_projection
|
||||
self.t5_pooled_projection.to(dtype=unet.dtype)
|
||||
|
||||
print("dtype of Linear is ", self.t5_projection.dtype)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
tokenizer=tokenizer,
|
||||
t5_encoder=self.t5_encoder,
|
||||
t5_projection=self.t5_projection,
|
||||
t5_pooled_projection=self.t5_pooled_projection,
|
||||
image_encoder=image_encoder,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
self.default_sample_size = (
|
||||
self.unet.config.sample_size
|
||||
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
|
||||
else 128
|
||||
)
|
||||
|
||||
self.watermark = None
|
||||
|
||||
# Parts of original SDXL class complain if these attributes are not
|
||||
# at least PRESENT
|
||||
self.text_encoder = self.text_encoder_2 = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Encode a text prompt (T5-XXL + 4096→2048 projection)
|
||||
# Returns exactly four tensors in the order SDXL’s __call__ expects.
|
||||
# ------------------------------------------------------------------
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
num_images_per_prompt: int = 1,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: str | None = None,
|
||||
**_,
|
||||
):
|
||||
"""
|
||||
Returns
|
||||
-------
|
||||
prompt_embeds : Tensor [B, T, 2048]
|
||||
negative_prompt_embeds : Tensor [B, T, 2048] | None
|
||||
pooled_prompt_embeds : Tensor [B, 1280]
|
||||
negative_pooled_prompt_embeds: Tensor [B, 1280] | None
|
||||
where B = batch * num_images_per_prompt
|
||||
"""
|
||||
|
||||
# --- helper to tokenize on the pipeline’s device ----------------
|
||||
def _tok(text: str):
|
||||
tok_out = self.tokenizer(
|
||||
text,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
).to(self.device)
|
||||
return tok_out.input_ids, tok_out.attention_mask
|
||||
|
||||
# ---------- positive stream -------------------------------------
|
||||
ids, mask = _tok(prompt)
|
||||
h_pos = self.t5_encoder(ids, attention_mask=mask).last_hidden_state # [b, T, 4096]
|
||||
tok_pos = self.t5_projection(h_pos) # [b, T, 2048]
|
||||
pool_pos = self.t5_pooled_projection(h_pos.mean(dim=1)) # [b, 1280]
|
||||
|
||||
# expand for multiple images per prompt
|
||||
tok_pos = tok_pos.repeat_interleave(num_images_per_prompt, 0)
|
||||
pool_pos = pool_pos.repeat_interleave(num_images_per_prompt, 0)
|
||||
|
||||
# ---------- negative / CFG stream --------------------------------
|
||||
if do_classifier_free_guidance:
|
||||
neg_text = "" if negative_prompt is None else negative_prompt
|
||||
ids_n, mask_n = _tok(neg_text)
|
||||
h_neg = self.t5_encoder(ids_n, attention_mask=mask_n).last_hidden_state
|
||||
tok_neg = self.t5_projection(h_neg)
|
||||
pool_neg = self.t5_pooled_projection(h_neg.mean(dim=1))
|
||||
|
||||
tok_neg = tok_neg.repeat_interleave(num_images_per_prompt, 0)
|
||||
pool_neg = pool_neg.repeat_interleave(num_images_per_prompt, 0)
|
||||
else:
|
||||
tok_neg = pool_neg = None
|
||||
|
||||
# ----------------- final ordered return --------------------------
|
||||
# 1) positive token embeddings
|
||||
# 2) negative token embeddings (or None)
|
||||
# 3) positive pooled embeddings
|
||||
# 4) negative pooled embeddings (or None)
|
||||
return tok_pos, tok_neg, pool_pos, pool_neg
|
||||
@@ -1,6 +1,6 @@
|
||||
import argparse
|
||||
import pathlib
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
@@ -14,6 +14,8 @@ from diffusers import (
|
||||
WanImageToVideoPipeline,
|
||||
WanPipeline,
|
||||
WanTransformer3DModel,
|
||||
WanVACEPipeline,
|
||||
WanVACETransformer3DModel,
|
||||
)
|
||||
|
||||
|
||||
@@ -59,7 +61,52 @@ TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
"attn2.norm_k_img": "attn2.norm_added_k",
|
||||
}
|
||||
|
||||
VACE_TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
"time_embedding.0": "condition_embedder.time_embedder.linear_1",
|
||||
"time_embedding.2": "condition_embedder.time_embedder.linear_2",
|
||||
"text_embedding.0": "condition_embedder.text_embedder.linear_1",
|
||||
"text_embedding.2": "condition_embedder.text_embedder.linear_2",
|
||||
"time_projection.1": "condition_embedder.time_proj",
|
||||
"head.modulation": "scale_shift_table",
|
||||
"head.head": "proj_out",
|
||||
"modulation": "scale_shift_table",
|
||||
"ffn.0": "ffn.net.0.proj",
|
||||
"ffn.2": "ffn.net.2",
|
||||
# Hack to swap the layer names
|
||||
# The original model calls the norms in following order: norm1, norm3, norm2
|
||||
# We convert it to: norm1, norm2, norm3
|
||||
"norm2": "norm__placeholder",
|
||||
"norm3": "norm2",
|
||||
"norm__placeholder": "norm3",
|
||||
# # For the I2V model
|
||||
# "img_emb.proj.0": "condition_embedder.image_embedder.norm1",
|
||||
# "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
|
||||
# "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
|
||||
# "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
|
||||
# # for the FLF2V model
|
||||
# "img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed",
|
||||
# Add attention component mappings
|
||||
"self_attn.q": "attn1.to_q",
|
||||
"self_attn.k": "attn1.to_k",
|
||||
"self_attn.v": "attn1.to_v",
|
||||
"self_attn.o": "attn1.to_out.0",
|
||||
"self_attn.norm_q": "attn1.norm_q",
|
||||
"self_attn.norm_k": "attn1.norm_k",
|
||||
"cross_attn.q": "attn2.to_q",
|
||||
"cross_attn.k": "attn2.to_k",
|
||||
"cross_attn.v": "attn2.to_v",
|
||||
"cross_attn.o": "attn2.to_out.0",
|
||||
"cross_attn.norm_q": "attn2.norm_q",
|
||||
"cross_attn.norm_k": "attn2.norm_k",
|
||||
"attn2.to_k_img": "attn2.add_k_proj",
|
||||
"attn2.to_v_img": "attn2.add_v_proj",
|
||||
"attn2.norm_k_img": "attn2.norm_added_k",
|
||||
"before_proj": "proj_in",
|
||||
"after_proj": "proj_out",
|
||||
}
|
||||
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = {}
|
||||
VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = {}
|
||||
|
||||
|
||||
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
|
||||
@@ -74,7 +121,7 @@ def load_sharded_safetensors(dir: pathlib.Path):
|
||||
return state_dict
|
||||
|
||||
|
||||
def get_transformer_config(model_type: str) -> Dict[str, Any]:
|
||||
def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
|
||||
if model_type == "Wan-T2V-1.3B":
|
||||
config = {
|
||||
"model_id": "StevenZhang/Wan2.1-T2V-1.3B-Diff",
|
||||
@@ -94,6 +141,8 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
|
||||
"text_dim": 4096,
|
||||
},
|
||||
}
|
||||
RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
|
||||
SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
elif model_type == "Wan-T2V-14B":
|
||||
config = {
|
||||
"model_id": "StevenZhang/Wan2.1-T2V-14B-Diff",
|
||||
@@ -113,6 +162,8 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
|
||||
"text_dim": 4096,
|
||||
},
|
||||
}
|
||||
RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
|
||||
SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
elif model_type == "Wan-I2V-14B-480p":
|
||||
config = {
|
||||
"model_id": "StevenZhang/Wan2.1-I2V-14B-480P-Diff",
|
||||
@@ -133,6 +184,8 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
|
||||
"text_dim": 4096,
|
||||
},
|
||||
}
|
||||
RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
|
||||
SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
elif model_type == "Wan-I2V-14B-720p":
|
||||
config = {
|
||||
"model_id": "StevenZhang/Wan2.1-I2V-14B-720P-Diff",
|
||||
@@ -153,6 +206,8 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
|
||||
"text_dim": 4096,
|
||||
},
|
||||
}
|
||||
RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
|
||||
SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
elif model_type == "Wan-FLF2V-14B-720P":
|
||||
config = {
|
||||
"model_id": "ypyp/Wan2.1-FLF2V-14B-720P", # This is just a placeholder
|
||||
@@ -175,11 +230,60 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
|
||||
"pos_embed_seq_len": 257 * 2,
|
||||
},
|
||||
}
|
||||
return config
|
||||
RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
|
||||
SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
elif model_type == "Wan-VACE-1.3B":
|
||||
config = {
|
||||
"model_id": "Wan-AI/Wan2.1-VACE-1.3B",
|
||||
"diffusers_config": {
|
||||
"added_kv_proj_dim": None,
|
||||
"attention_head_dim": 128,
|
||||
"cross_attn_norm": True,
|
||||
"eps": 1e-06,
|
||||
"ffn_dim": 8960,
|
||||
"freq_dim": 256,
|
||||
"in_channels": 16,
|
||||
"num_attention_heads": 12,
|
||||
"num_layers": 30,
|
||||
"out_channels": 16,
|
||||
"patch_size": [1, 2, 2],
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"text_dim": 4096,
|
||||
"vace_layers": [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28],
|
||||
"vace_in_channels": 96,
|
||||
},
|
||||
}
|
||||
RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
|
||||
SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
elif model_type == "Wan-VACE-14B":
|
||||
config = {
|
||||
"model_id": "Wan-AI/Wan2.1-VACE-14B",
|
||||
"diffusers_config": {
|
||||
"added_kv_proj_dim": None,
|
||||
"attention_head_dim": 128,
|
||||
"cross_attn_norm": True,
|
||||
"eps": 1e-06,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"in_channels": 16,
|
||||
"num_attention_heads": 40,
|
||||
"num_layers": 40,
|
||||
"out_channels": 16,
|
||||
"patch_size": [1, 2, 2],
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"text_dim": 4096,
|
||||
"vace_layers": [0, 5, 10, 15, 20, 25, 30, 35],
|
||||
"vace_in_channels": 96,
|
||||
},
|
||||
}
|
||||
RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
|
||||
SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
return config, RENAME_DICT, SPECIAL_KEYS_REMAP
|
||||
|
||||
|
||||
def convert_transformer(model_type: str):
|
||||
config = get_transformer_config(model_type)
|
||||
config, RENAME_DICT, SPECIAL_KEYS_REMAP = get_transformer_config(model_type)
|
||||
|
||||
diffusers_config = config["diffusers_config"]
|
||||
model_id = config["model_id"]
|
||||
model_dir = pathlib.Path(snapshot_download(model_id, repo_type="model"))
|
||||
@@ -187,16 +291,19 @@ def convert_transformer(model_type: str):
|
||||
original_state_dict = load_sharded_safetensors(model_dir)
|
||||
|
||||
with init_empty_weights():
|
||||
transformer = WanTransformer3DModel.from_config(diffusers_config)
|
||||
if "VACE" not in model_type:
|
||||
transformer = WanTransformer3DModel.from_config(diffusers_config)
|
||||
else:
|
||||
transformer = WanVACETransformer3DModel.from_config(diffusers_config)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
|
||||
for replace_key, rename_key in RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_(original_state_dict, key, new_key)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
|
||||
for special_key, handler_fn_inplace in SPECIAL_KEYS_REMAP.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
@@ -412,7 +519,7 @@ def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model_type", type=str, default=None)
|
||||
parser.add_argument("--output_path", type=str, required=True)
|
||||
parser.add_argument("--dtype", default="fp32")
|
||||
parser.add_argument("--dtype", default="fp32", choices=["fp32", "fp16", "bf16", "none"])
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@@ -426,18 +533,20 @@ DTYPE_MAPPING = {
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
|
||||
transformer = None
|
||||
dtype = DTYPE_MAPPING[args.dtype]
|
||||
|
||||
transformer = convert_transformer(args.model_type).to(dtype=dtype)
|
||||
transformer = convert_transformer(args.model_type)
|
||||
vae = convert_vae()
|
||||
text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl")
|
||||
text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl", torch_dtype=torch.bfloat16)
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
|
||||
flow_shift = 16.0 if "FLF2V" in args.model_type else 3.0
|
||||
scheduler = UniPCMultistepScheduler(
|
||||
prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift
|
||||
)
|
||||
|
||||
# If user has specified "none", we keep the original dtypes of the state dict without any conversion
|
||||
if args.dtype != "none":
|
||||
dtype = DTYPE_MAPPING[args.dtype]
|
||||
transformer.to(dtype)
|
||||
|
||||
if "I2V" in args.model_type or "FLF2V" in args.model_type:
|
||||
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
|
||||
"laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
|
||||
@@ -452,6 +561,14 @@ if __name__ == "__main__":
|
||||
image_encoder=image_encoder,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
elif "VACE" in args.model_type:
|
||||
pipe = WanVACEPipeline(
|
||||
transformer=transformer,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
else:
|
||||
pipe = WanPipeline(
|
||||
transformer=transformer,
|
||||
|
||||
@@ -215,6 +215,7 @@ else:
|
||||
"UVit2DModel",
|
||||
"VQModel",
|
||||
"WanTransformer3DModel",
|
||||
"WanVACETransformer3DModel",
|
||||
]
|
||||
)
|
||||
_import_structure["optimization"] = [
|
||||
@@ -527,6 +528,7 @@ else:
|
||||
"VQDiffusionPipeline",
|
||||
"WanImageToVideoPipeline",
|
||||
"WanPipeline",
|
||||
"WanVACEPipeline",
|
||||
"WanVideoToVideoPipeline",
|
||||
"WuerstchenCombinedPipeline",
|
||||
"WuerstchenDecoderPipeline",
|
||||
@@ -821,6 +823,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
UVit2DModel,
|
||||
VQModel,
|
||||
WanTransformer3DModel,
|
||||
WanVACETransformer3DModel,
|
||||
)
|
||||
from .optimization import (
|
||||
get_constant_schedule,
|
||||
@@ -1112,6 +1115,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
VQDiffusionPipeline,
|
||||
WanImageToVideoPipeline,
|
||||
WanPipeline,
|
||||
WanVACEPipeline,
|
||||
WanVideoToVideoPipeline,
|
||||
WuerstchenCombinedPipeline,
|
||||
WuerstchenDecoderPipeline,
|
||||
|
||||
@@ -58,6 +58,7 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
|
||||
"CogView4Transformer2DModel": lambda model_cls, weights: weights,
|
||||
"HiDreamImageTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"WanVACETransformer3DModel": lambda model_cls, weights: weights,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -89,6 +89,7 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
|
||||
_import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"]
|
||||
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
|
||||
_import_structure["unets.unet_2d"] = ["UNet2DModel"]
|
||||
_import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"]
|
||||
@@ -178,6 +179,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
Transformer2DModel,
|
||||
TransformerTemporalModel,
|
||||
WanTransformer3DModel,
|
||||
WanVACETransformer3DModel,
|
||||
)
|
||||
from .unets import (
|
||||
I2VGenXLUNet,
|
||||
|
||||
@@ -32,3 +32,4 @@ if is_torch_available():
|
||||
from .transformer_sd3 import SD3Transformer2DModel
|
||||
from .transformer_temporal import TransformerTemporalModel
|
||||
from .transformer_wan import WanTransformer3DModel
|
||||
from .transformer_wan_vace import WanVACETransformer3DModel
|
||||
|
||||
@@ -241,7 +241,7 @@ class FluxTransformer2DModel(
|
||||
joint_attention_dim: int = 4096,
|
||||
pooled_projection_dim: int = 768,
|
||||
guidance_embeds: bool = False,
|
||||
axes_dims_rope: Tuple[int] = (16, 56, 56),
|
||||
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = out_channels or in_channels
|
||||
|
||||
393
src/diffusers/models/transformers/transformer_wan_vace.py
Normal file
393
src/diffusers/models/transformers/transformer_wan_vace.py
Normal file
@@ -0,0 +1,393 @@
|
||||
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import Attention
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import FP32LayerNorm
|
||||
from .transformer_wan import WanAttnProcessor2_0, WanRotaryPosEmbed, WanTimeTextImageEmbedding, WanTransformerBlock
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class WanVACETransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
ffn_dim: int,
|
||||
num_heads: int,
|
||||
qk_norm: str = "rms_norm_across_heads",
|
||||
cross_attn_norm: bool = False,
|
||||
eps: float = 1e-6,
|
||||
added_kv_proj_dim: Optional[int] = None,
|
||||
apply_input_projection: bool = False,
|
||||
apply_output_projection: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# 1. Input projection
|
||||
self.proj_in = None
|
||||
if apply_input_projection:
|
||||
self.proj_in = nn.Linear(dim, dim)
|
||||
|
||||
# 2. Self-attention
|
||||
self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_heads,
|
||||
kv_heads=num_heads,
|
||||
dim_head=dim // num_heads,
|
||||
qk_norm=qk_norm,
|
||||
eps=eps,
|
||||
bias=True,
|
||||
cross_attention_dim=None,
|
||||
out_bias=True,
|
||||
processor=WanAttnProcessor2_0(),
|
||||
)
|
||||
|
||||
# 3. Cross-attention
|
||||
self.attn2 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_heads,
|
||||
kv_heads=num_heads,
|
||||
dim_head=dim // num_heads,
|
||||
qk_norm=qk_norm,
|
||||
eps=eps,
|
||||
bias=True,
|
||||
cross_attention_dim=None,
|
||||
out_bias=True,
|
||||
added_kv_proj_dim=added_kv_proj_dim,
|
||||
added_proj_bias=True,
|
||||
processor=WanAttnProcessor2_0(),
|
||||
)
|
||||
self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
|
||||
|
||||
# 4. Feed-forward
|
||||
self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
|
||||
self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
|
||||
|
||||
# 5. Output projection
|
||||
self.proj_out = None
|
||||
if apply_output_projection:
|
||||
self.proj_out = nn.Linear(dim, dim)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
control_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
rotary_emb: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
if self.proj_in is not None:
|
||||
control_hidden_states = self.proj_in(control_hidden_states)
|
||||
control_hidden_states = control_hidden_states + hidden_states
|
||||
|
||||
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
|
||||
self.scale_shift_table + temb.float()
|
||||
).chunk(6, dim=1)
|
||||
|
||||
# 1. Self-attention
|
||||
norm_hidden_states = (self.norm1(control_hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(
|
||||
control_hidden_states
|
||||
)
|
||||
attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
|
||||
control_hidden_states = (control_hidden_states.float() + attn_output * gate_msa).type_as(control_hidden_states)
|
||||
|
||||
# 2. Cross-attention
|
||||
norm_hidden_states = self.norm2(control_hidden_states.float()).type_as(control_hidden_states)
|
||||
attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
|
||||
control_hidden_states = control_hidden_states + attn_output
|
||||
|
||||
# 3. Feed-forward
|
||||
norm_hidden_states = (self.norm3(control_hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
|
||||
control_hidden_states
|
||||
)
|
||||
ff_output = self.ffn(norm_hidden_states)
|
||||
control_hidden_states = (control_hidden_states.float() + ff_output.float() * c_gate_msa).type_as(
|
||||
control_hidden_states
|
||||
)
|
||||
|
||||
conditioning_states = None
|
||||
if self.proj_out is not None:
|
||||
conditioning_states = self.proj_out(control_hidden_states)
|
||||
|
||||
return conditioning_states, control_hidden_states
|
||||
|
||||
|
||||
class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
|
||||
r"""
|
||||
A Transformer model for video-like data used in the Wan model.
|
||||
|
||||
Args:
|
||||
patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
|
||||
3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
|
||||
num_attention_heads (`int`, defaults to `40`):
|
||||
Fixed length for text embeddings.
|
||||
attention_head_dim (`int`, defaults to `128`):
|
||||
The number of channels in each head.
|
||||
in_channels (`int`, defaults to `16`):
|
||||
The number of channels in the input.
|
||||
out_channels (`int`, defaults to `16`):
|
||||
The number of channels in the output.
|
||||
text_dim (`int`, defaults to `512`):
|
||||
Input dimension for text embeddings.
|
||||
freq_dim (`int`, defaults to `256`):
|
||||
Dimension for sinusoidal time embeddings.
|
||||
ffn_dim (`int`, defaults to `13824`):
|
||||
Intermediate dimension in feed-forward network.
|
||||
num_layers (`int`, defaults to `40`):
|
||||
The number of layers of transformer blocks to use.
|
||||
window_size (`Tuple[int]`, defaults to `(-1, -1)`):
|
||||
Window size for local attention (-1 indicates global attention).
|
||||
cross_attn_norm (`bool`, defaults to `True`):
|
||||
Enable cross-attention normalization.
|
||||
qk_norm (`bool`, defaults to `True`):
|
||||
Enable query/key normalization.
|
||||
eps (`float`, defaults to `1e-6`):
|
||||
Epsilon value for normalization layers.
|
||||
add_img_emb (`bool`, defaults to `False`):
|
||||
Whether to use img_emb.
|
||||
added_kv_proj_dim (`int`, *optional*, defaults to `None`):
|
||||
The number of channels to use for the added key and value projections. If `None`, no projection is used.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_skip_layerwise_casting_patterns = ["patch_embedding", "vace_patch_embedding", "condition_embedder", "norm"]
|
||||
_no_split_modules = ["WanTransformerBlock", "WanVACETransformerBlock"]
|
||||
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
|
||||
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: Tuple[int] = (1, 2, 2),
|
||||
num_attention_heads: int = 40,
|
||||
attention_head_dim: int = 128,
|
||||
in_channels: int = 16,
|
||||
out_channels: int = 16,
|
||||
text_dim: int = 4096,
|
||||
freq_dim: int = 256,
|
||||
ffn_dim: int = 13824,
|
||||
num_layers: int = 40,
|
||||
cross_attn_norm: bool = True,
|
||||
qk_norm: Optional[str] = "rms_norm_across_heads",
|
||||
eps: float = 1e-6,
|
||||
image_dim: Optional[int] = None,
|
||||
added_kv_proj_dim: Optional[int] = None,
|
||||
rope_max_seq_len: int = 1024,
|
||||
pos_embed_seq_len: Optional[int] = None,
|
||||
vace_layers: List[int] = [0, 5, 10, 15, 20, 25, 30, 35],
|
||||
vace_in_channels: int = 96,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
out_channels = out_channels or in_channels
|
||||
|
||||
if max(vace_layers) >= num_layers:
|
||||
raise ValueError(f"VACE layers {vace_layers} exceed the number of transformer layers {num_layers}.")
|
||||
if 0 not in vace_layers:
|
||||
raise ValueError("VACE layers must include layer 0.")
|
||||
|
||||
# 1. Patch & position embedding
|
||||
self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
|
||||
self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
|
||||
self.vace_patch_embedding = nn.Conv3d(vace_in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
|
||||
|
||||
# 2. Condition embeddings
|
||||
# image_embedding_dim=1280 for I2V model
|
||||
self.condition_embedder = WanTimeTextImageEmbedding(
|
||||
dim=inner_dim,
|
||||
time_freq_dim=freq_dim,
|
||||
time_proj_dim=inner_dim * 6,
|
||||
text_embed_dim=text_dim,
|
||||
image_embed_dim=image_dim,
|
||||
pos_embed_seq_len=pos_embed_seq_len,
|
||||
)
|
||||
|
||||
# 3. Transformer blocks
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
WanTransformerBlock(
|
||||
inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.vace_blocks = nn.ModuleList(
|
||||
[
|
||||
WanVACETransformerBlock(
|
||||
inner_dim,
|
||||
ffn_dim,
|
||||
num_attention_heads,
|
||||
qk_norm,
|
||||
cross_attn_norm,
|
||||
eps,
|
||||
added_kv_proj_dim,
|
||||
apply_input_projection=i == 0, # Layer 0 always has input projection and is in vace_layers
|
||||
apply_output_projection=True,
|
||||
)
|
||||
for i in range(len(vace_layers))
|
||||
]
|
||||
)
|
||||
|
||||
# 4. Output norm & projection
|
||||
self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
|
||||
self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
timestep: torch.LongTensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
encoder_hidden_states_image: Optional[torch.Tensor] = None,
|
||||
control_hidden_states: torch.Tensor = None,
|
||||
control_hidden_states_scale: torch.Tensor = None,
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
p_t, p_h, p_w = self.config.patch_size
|
||||
post_patch_num_frames = num_frames // p_t
|
||||
post_patch_height = height // p_h
|
||||
post_patch_width = width // p_w
|
||||
|
||||
if control_hidden_states_scale is None:
|
||||
control_hidden_states_scale = control_hidden_states.new_ones(len(self.config.vace_layers))
|
||||
control_hidden_states_scale = torch.unbind(control_hidden_states_scale)
|
||||
if len(control_hidden_states_scale) != len(self.config.vace_layers):
|
||||
raise ValueError(
|
||||
f"Length of `control_hidden_states_scale` {len(control_hidden_states_scale)} should be "
|
||||
f"equal to {len(self.config.vace_layers)}."
|
||||
)
|
||||
|
||||
# 1. Rotary position embedding
|
||||
rotary_emb = self.rope(hidden_states)
|
||||
|
||||
# 2. Patch embedding
|
||||
hidden_states = self.patch_embedding(hidden_states)
|
||||
hidden_states = hidden_states.flatten(2).transpose(1, 2)
|
||||
|
||||
control_hidden_states = self.vace_patch_embedding(control_hidden_states)
|
||||
control_hidden_states = control_hidden_states.flatten(2).transpose(1, 2)
|
||||
control_hidden_states_padding = control_hidden_states.new_zeros(
|
||||
batch_size, hidden_states.size(1) - control_hidden_states.size(1), control_hidden_states.size(2)
|
||||
)
|
||||
control_hidden_states = torch.cat([control_hidden_states, control_hidden_states_padding], dim=1)
|
||||
|
||||
# 3. Time embedding
|
||||
temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
|
||||
timestep, encoder_hidden_states, encoder_hidden_states_image
|
||||
)
|
||||
timestep_proj = timestep_proj.unflatten(1, (6, -1))
|
||||
|
||||
# 4. Image embedding
|
||||
if encoder_hidden_states_image is not None:
|
||||
encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
|
||||
|
||||
# 5. Transformer blocks
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
# Prepare VACE hints
|
||||
control_hidden_states_list = []
|
||||
for i, block in enumerate(self.vace_blocks):
|
||||
conditioning_states, control_hidden_states = self._gradient_checkpointing_func(
|
||||
block, hidden_states, encoder_hidden_states, control_hidden_states, timestep_proj, rotary_emb
|
||||
)
|
||||
control_hidden_states_list.append((conditioning_states, control_hidden_states_scale[i]))
|
||||
control_hidden_states_list = control_hidden_states_list[::-1]
|
||||
|
||||
for i, block in enumerate(self.blocks):
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
|
||||
)
|
||||
if i in self.config.vace_layers:
|
||||
control_hint, scale = control_hidden_states_list.pop()
|
||||
hidden_states = hidden_states + control_hint * scale
|
||||
else:
|
||||
# Prepare VACE hints
|
||||
control_hidden_states_list = []
|
||||
for i, block in enumerate(self.vace_blocks):
|
||||
conditioning_states, control_hidden_states = block(
|
||||
hidden_states, encoder_hidden_states, control_hidden_states, timestep_proj, rotary_emb
|
||||
)
|
||||
control_hidden_states_list.append((conditioning_states, control_hidden_states_scale[i]))
|
||||
control_hidden_states_list = control_hidden_states_list[::-1]
|
||||
|
||||
for i, block in enumerate(self.blocks):
|
||||
hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
|
||||
if i in self.config.vace_layers:
|
||||
control_hint, scale = control_hidden_states_list.pop()
|
||||
hidden_states = hidden_states + control_hint * scale
|
||||
|
||||
# 6. Output norm, projection & unpatchify
|
||||
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
|
||||
|
||||
# Move the shift and scale tensors to the same device as hidden_states.
|
||||
# When using multi-GPU inference via accelerate these will be on the
|
||||
# first device rather than the last device, which hidden_states ends up
|
||||
# on.
|
||||
shift = shift.to(hidden_states.device)
|
||||
scale = scale.to(hidden_states.device)
|
||||
|
||||
hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.reshape(
|
||||
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
|
||||
)
|
||||
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
|
||||
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
@@ -371,7 +371,7 @@ else:
|
||||
"WuerstchenDecoderPipeline",
|
||||
"WuerstchenPriorPipeline",
|
||||
]
|
||||
_import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline", "WanVideoToVideoPipeline"]
|
||||
_import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline", "WanVideoToVideoPipeline", "WanVACEPipeline"]
|
||||
try:
|
||||
if not is_onnx_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
@@ -739,7 +739,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
UniDiffuserTextDecoder,
|
||||
)
|
||||
from .visualcloze import VisualClozeGenerationPipeline, VisualClozePipeline
|
||||
from .wan import WanImageToVideoPipeline, WanPipeline, WanVideoToVideoPipeline
|
||||
from .wan import WanImageToVideoPipeline, WanPipeline, WanVACEPipeline, WanVideoToVideoPipeline
|
||||
from .wuerstchen import (
|
||||
WuerstchenCombinedPipeline,
|
||||
WuerstchenDecoderPipeline,
|
||||
|
||||
@@ -21,7 +21,7 @@ from ...image_processor import VaeImageProcessor
|
||||
from ...models import UVit2DModel, VQModel
|
||||
from ...schedulers import AmusedScheduler
|
||||
from ...utils import is_torch_xla_available, replace_example_docstring
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
@@ -47,7 +47,8 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
class AmusedPipeline(DiffusionPipeline):
|
||||
class AmusedPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
|
||||
_last_supported_version = "0.33.1"
|
||||
image_processor: VaeImageProcessor
|
||||
vqvae: VQModel
|
||||
tokenizer: CLIPTokenizer
|
||||
|
||||
@@ -21,7 +21,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...models import UVit2DModel, VQModel
|
||||
from ...schedulers import AmusedScheduler
|
||||
from ...utils import is_torch_xla_available, replace_example_docstring
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
@@ -57,7 +57,8 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
class AmusedImg2ImgPipeline(DiffusionPipeline):
|
||||
class AmusedImg2ImgPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
|
||||
_last_supported_version = "0.33.1"
|
||||
image_processor: VaeImageProcessor
|
||||
vqvae: VQModel
|
||||
tokenizer: CLIPTokenizer
|
||||
|
||||
@@ -22,7 +22,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...models import UVit2DModel, VQModel
|
||||
from ...schedulers import AmusedScheduler
|
||||
from ...utils import is_torch_xla_available, replace_example_docstring
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
@@ -65,7 +65,8 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
class AmusedInpaintPipeline(DiffusionPipeline):
|
||||
class AmusedInpaintPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
|
||||
_last_supported_version = "0.33.1"
|
||||
image_processor: VaeImageProcessor
|
||||
vqvae: VQModel
|
||||
tokenizer: CLIPTokenizer
|
||||
|
||||
@@ -24,7 +24,7 @@ from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline, StableDiffusionMixin
|
||||
from ..pipeline_utils import AudioPipelineOutput, DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
@@ -57,7 +57,7 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
class AudioLDMPipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
class AudioLDMPipeline(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin):
|
||||
r"""
|
||||
Pipeline for text-to-audio generation using AudioLDM.
|
||||
|
||||
@@ -81,6 +81,7 @@ class AudioLDMPipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
Vocoder of class `SpeechT5HifiGan`.
|
||||
"""
|
||||
|
||||
_last_supported_version = "0.33.1"
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -25,7 +25,7 @@ from ...utils import (
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput
|
||||
from .blip_image_processing import BlipImageProcessor
|
||||
from .modeling_blip2 import Blip2QFormerModel
|
||||
from .modeling_ctx_clip import ContextCLIPTextModel
|
||||
@@ -81,7 +81,7 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
class BlipDiffusionPipeline(DiffusionPipeline):
|
||||
class BlipDiffusionPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for Zero-Shot Subject Driven Generation using Blip Diffusion.
|
||||
|
||||
@@ -107,6 +107,7 @@ class BlipDiffusionPipeline(DiffusionPipeline):
|
||||
Position of the context token in the text encoder.
|
||||
"""
|
||||
|
||||
_last_supported_version = "0.33.1"
|
||||
model_cpu_offload_seq = "qformer->text_encoder->unet->vae"
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -37,7 +37,7 @@ from ...utils import (
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
||||
from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
|
||||
@@ -98,6 +98,7 @@ EXAMPLE_DOC_STRING = """
|
||||
|
||||
|
||||
class StableDiffusionControlNetXSPipeline(
|
||||
DeprecatedPipelineMixin,
|
||||
DiffusionPipeline,
|
||||
StableDiffusionMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
@@ -138,6 +139,7 @@ class StableDiffusionControlNetXSPipeline(
|
||||
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
||||
"""
|
||||
|
||||
_last_supported_version = "0.33.1"
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
@@ -46,7 +46,7 @@ from ...utils import (
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline
|
||||
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
||||
|
||||
|
||||
@@ -114,6 +114,7 @@ EXAMPLE_DOC_STRING = """
|
||||
|
||||
|
||||
class StableDiffusionXLControlNetXSPipeline(
|
||||
DeprecatedPipelineMixin,
|
||||
DiffusionPipeline,
|
||||
TextualInversionLoaderMixin,
|
||||
StableDiffusionXLLoraLoaderMixin,
|
||||
@@ -158,6 +159,7 @@ class StableDiffusionXLControlNetXSPipeline(
|
||||
watermarker is used.
|
||||
"""
|
||||
|
||||
_last_supported_version = "0.33.1"
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
|
||||
_optional_components = [
|
||||
"tokenizer",
|
||||
|
||||
@@ -21,7 +21,7 @@ from ...models import UNet1DModel
|
||||
from ...schedulers import SchedulerMixin
|
||||
from ...utils import is_torch_xla_available, logging
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
|
||||
from ..pipeline_utils import AudioPipelineOutput, DeprecatedPipelineMixin, DiffusionPipeline
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
@@ -34,7 +34,7 @@ else:
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class DanceDiffusionPipeline(DiffusionPipeline):
|
||||
class DanceDiffusionPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for audio generation.
|
||||
|
||||
@@ -49,6 +49,7 @@ class DanceDiffusionPipeline(DiffusionPipeline):
|
||||
[`IPNDMScheduler`].
|
||||
"""
|
||||
|
||||
_last_supported_version = "0.33.1"
|
||||
model_cpu_offload_seq = "unet"
|
||||
|
||||
def __init__(self, unet: UNet1DModel, scheduler: SchedulerMixin):
|
||||
|
||||
@@ -33,7 +33,7 @@ from ...utils import (
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
@@ -97,9 +97,11 @@ class I2VGenXLPipelineOutput(BaseOutput):
|
||||
|
||||
|
||||
class I2VGenXLPipeline(
|
||||
DeprecatedPipelineMixin,
|
||||
DiffusionPipeline,
|
||||
StableDiffusionMixin,
|
||||
):
|
||||
_last_supported_version = "0.33.1"
|
||||
r"""
|
||||
Pipeline for image-to-video generation as proposed in [I2VGenXL](https://i2vgen-xl.github.io/).
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ from ...utils import (
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline, StableDiffusionMixin
|
||||
from ..pipeline_utils import AudioPipelineOutput, DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
||||
|
||||
|
||||
if is_librosa_available():
|
||||
@@ -76,7 +76,8 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
class MusicLDMPipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
class MusicLDMPipeline(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin):
|
||||
_last_supported_version = "0.33.1"
|
||||
r"""
|
||||
Pipeline for text-to-audio generation using MusicLDM.
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||
from ...utils import deprecate, is_torch_xla_available, logging
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
||||
from ..stable_diffusion import StableDiffusionPipelineOutput
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from .image_encoder import PaintByExampleImageEncoder
|
||||
@@ -155,7 +155,8 @@ def prepare_mask_and_masked_image(image, mask):
|
||||
return mask, masked_image
|
||||
|
||||
|
||||
class PaintByExamplePipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
class PaintByExamplePipeline(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin):
|
||||
_last_supported_version = "0.33.1"
|
||||
r"""
|
||||
<Tip warning={true}>
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ from ...utils import (
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..free_init_utils import FreeInitMixin
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
@@ -132,6 +132,7 @@ class PIAPipelineOutput(BaseOutput):
|
||||
|
||||
|
||||
class PIAPipeline(
|
||||
DeprecatedPipelineMixin,
|
||||
DiffusionPipeline,
|
||||
StableDiffusionMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
@@ -140,6 +141,7 @@ class PIAPipeline(
|
||||
FromSingleFileMixin,
|
||||
FreeInitMixin,
|
||||
):
|
||||
_last_supported_version = "0.33.1"
|
||||
r"""
|
||||
Pipeline for text-to-video generation.
|
||||
|
||||
|
||||
@@ -139,6 +139,43 @@ class AudioPipelineOutput(BaseOutput):
|
||||
audios: np.ndarray
|
||||
|
||||
|
||||
class DeprecatedPipelineMixin:
|
||||
"""
|
||||
A mixin that can be used to mark a pipeline as deprecated.
|
||||
|
||||
Pipelines inheriting from this mixin will raise a warning when instantiated, indicating that they are deprecated
|
||||
and won't receive updates past the specified version. Tests will be skipped for pipelines that inherit from this
|
||||
mixin.
|
||||
|
||||
Example usage:
|
||||
```python
|
||||
class MyDeprecatedPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
|
||||
_last_supported_version = "0.20.0"
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
```
|
||||
"""
|
||||
|
||||
# Override this in the inheriting class to specify the last version that will support this pipeline
|
||||
_last_supported_version = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
# Get the class name for the warning message
|
||||
class_name = self.__class__.__name__
|
||||
|
||||
# Get the last supported version or use the current version if not specified
|
||||
version_info = getattr(self.__class__, "_last_supported_version", __version__)
|
||||
|
||||
# Raise a warning that this pipeline is deprecated
|
||||
logger.warning(
|
||||
f"The {class_name} has been deprecated and will not receive bug fixes or feature updates after Diffusers version {version_info}. "
|
||||
)
|
||||
|
||||
# Call the parent class's __init__ method
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
r"""
|
||||
Base class for all pipelines.
|
||||
|
||||
@@ -11,7 +11,7 @@ from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyCh
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import deprecate, is_torch_xla_available, logging
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
||||
from .pipeline_output import SemanticStableDiffusionPipelineOutput
|
||||
|
||||
|
||||
@@ -25,7 +25,8 @@ else:
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
class SemanticStableDiffusionPipeline(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin):
|
||||
_last_supported_version = "0.33.1"
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion with latent editing.
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ from ...utils import (
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
||||
from ..stable_diffusion import StableDiffusionPipelineOutput
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
|
||||
@@ -179,7 +179,9 @@ class AttendExciteAttnProcessor:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin):
|
||||
class StableDiffusionAttendAndExcitePipeline(
|
||||
DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin
|
||||
):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion and Attend-and-Excite.
|
||||
|
||||
@@ -209,6 +211,8 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, StableDiffusionM
|
||||
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
||||
"""
|
||||
|
||||
_last_supported_version = "0.33.1"
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
@@ -40,7 +40,7 @@ from ...utils import (
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
||||
from ..stable_diffusion import StableDiffusionPipelineOutput
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
|
||||
@@ -242,7 +242,11 @@ def preprocess_mask(mask, batch_size: int = 1):
|
||||
|
||||
|
||||
class StableDiffusionDiffEditPipeline(
|
||||
DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin
|
||||
DeprecatedPipelineMixin,
|
||||
DiffusionPipeline,
|
||||
StableDiffusionMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
):
|
||||
r"""
|
||||
<Tip warning={true}>
|
||||
@@ -282,6 +286,8 @@ class StableDiffusionDiffEditPipeline(
|
||||
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
||||
"""
|
||||
|
||||
_last_supported_version = "0.33.1"
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor", "inverse_scheduler"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
@@ -36,7 +36,7 @@ from ...utils import (
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
||||
from ..stable_diffusion import StableDiffusionPipelineOutput
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
|
||||
@@ -108,7 +108,7 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
class StableDiffusionGLIGENPipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
class StableDiffusionGLIGENPipeline(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion with Grounded-Language-to-Image Generation (GLIGEN).
|
||||
|
||||
@@ -135,6 +135,8 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
||||
"""
|
||||
|
||||
_last_supported_version = "0.33.1"
|
||||
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
@@ -41,7 +41,7 @@ from ...utils import (
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
||||
from ..stable_diffusion import StableDiffusionPipelineOutput
|
||||
from ..stable_diffusion.clip_image_project_model import CLIPImageProjection
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
@@ -160,7 +160,7 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline, StableDiffusionMixin):
|
||||
class StableDiffusionGLIGENTextImagePipeline(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion with Grounded-Language-to-Image Generation (GLIGEN).
|
||||
|
||||
@@ -193,6 +193,8 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline, StableDiffusionM
|
||||
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
|
||||
"""
|
||||
|
||||
_last_supported_version = "0.33.1"
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
@@ -42,7 +42,7 @@ from ...utils import (
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
||||
from ..stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
|
||||
|
||||
|
||||
@@ -64,7 +64,11 @@ class ModelWrapper:
|
||||
|
||||
|
||||
class StableDiffusionKDiffusionPipeline(
|
||||
DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin
|
||||
DeprecatedPipelineMixin,
|
||||
DiffusionPipeline,
|
||||
StableDiffusionMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion.
|
||||
@@ -105,6 +109,8 @@ class StableDiffusionKDiffusionPipeline(
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
_last_supported_version = "0.33.1"
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
@@ -48,7 +48,7 @@ from ...utils import (
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
||||
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
||||
|
||||
|
||||
@@ -88,6 +88,7 @@ class ModelWrapper:
|
||||
|
||||
|
||||
class StableDiffusionXLKDiffusionPipeline(
|
||||
DeprecatedPipelineMixin,
|
||||
DiffusionPipeline,
|
||||
StableDiffusionMixin,
|
||||
FromSingleFileMixin,
|
||||
@@ -95,6 +96,8 @@ class StableDiffusionXLKDiffusionPipeline(
|
||||
TextualInversionLoaderMixin,
|
||||
IPAdapterMixin,
|
||||
):
|
||||
_last_supported_version = "0.33.1"
|
||||
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion XL and k-diffusion.
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ from ...utils import (
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
|
||||
|
||||
@@ -178,6 +178,7 @@ class LDM3DPipelineOutput(BaseOutput):
|
||||
|
||||
|
||||
class StableDiffusionLDM3DPipeline(
|
||||
DeprecatedPipelineMixin,
|
||||
DiffusionPipeline,
|
||||
StableDiffusionMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
@@ -185,6 +186,8 @@ class StableDiffusionLDM3DPipeline(
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
FromSingleFileMixin,
|
||||
):
|
||||
_last_supported_version = "0.33.1"
|
||||
|
||||
r"""
|
||||
Pipeline for text-to-image and 3D generation using LDM3D.
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ from ...utils import (
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
||||
from ..stable_diffusion import StableDiffusionPipelineOutput
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
|
||||
@@ -156,12 +156,15 @@ def retrieve_timesteps(
|
||||
|
||||
|
||||
class StableDiffusionPanoramaPipeline(
|
||||
DeprecatedPipelineMixin,
|
||||
DiffusionPipeline,
|
||||
StableDiffusionMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
IPAdapterMixin,
|
||||
):
|
||||
_last_supported_version = "0.33.1"
|
||||
|
||||
r"""
|
||||
Pipeline for text-to-image generation using MultiDiffusion.
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import deprecate, is_torch_xla_available, logging
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
||||
from . import StableDiffusionSafePipelineOutput
|
||||
from .safety_checker import SafeStableDiffusionSafetyChecker
|
||||
|
||||
@@ -29,7 +29,9 @@ else:
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class StableDiffusionPipelineSafe(DiffusionPipeline, StableDiffusionMixin, IPAdapterMixin):
|
||||
class StableDiffusionPipelineSafe(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin, IPAdapterMixin):
|
||||
_last_supported_version = "0.33.1"
|
||||
|
||||
r"""
|
||||
Pipeline based on the [`StableDiffusionPipeline`] for text-to-image generation using Safe Latent Diffusion.
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ from ...utils import (
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
||||
from ..stable_diffusion import StableDiffusionPipelineOutput
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
|
||||
@@ -107,7 +107,11 @@ class CrossAttnStoreProcessor:
|
||||
|
||||
|
||||
# Modified to get self-attention guidance scale in this paper (https://huggingface.co/papers/2210.00939) as an input
|
||||
class StableDiffusionSAGPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, IPAdapterMixin):
|
||||
class StableDiffusionSAGPipeline(
|
||||
DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, IPAdapterMixin
|
||||
):
|
||||
_last_supported_version = "0.33.1"
|
||||
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion.
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ from ...utils import (
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
||||
from . import TextToVideoSDPipelineOutput
|
||||
|
||||
|
||||
@@ -68,8 +68,13 @@ EXAMPLE_DOC_STRING = """
|
||||
|
||||
|
||||
class TextToVideoSDPipeline(
|
||||
DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin
|
||||
DeprecatedPipelineMixin,
|
||||
DiffusionPipeline,
|
||||
StableDiffusionMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
):
|
||||
_last_supported_version = "0.33.1"
|
||||
r"""
|
||||
Pipeline for text-to-video generation.
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ from ...utils import (
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
||||
from . import TextToVideoSDPipelineOutput
|
||||
|
||||
|
||||
@@ -103,8 +103,13 @@ def retrieve_latents(
|
||||
|
||||
|
||||
class VideoToVideoSDPipeline(
|
||||
DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin
|
||||
DeprecatedPipelineMixin,
|
||||
DiffusionPipeline,
|
||||
StableDiffusionMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
):
|
||||
_last_supported_version = "0.33.1"
|
||||
r"""
|
||||
Pipeline for text-guided video-to-video generation.
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ from ...utils import (
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
||||
from ..stable_diffusion import StableDiffusionSafetyChecker
|
||||
|
||||
|
||||
@@ -296,12 +296,14 @@ def create_motion_field_and_warp_latents(motion_field_strength_x, motion_field_s
|
||||
|
||||
|
||||
class TextToVideoZeroPipeline(
|
||||
DeprecatedPipelineMixin,
|
||||
DiffusionPipeline,
|
||||
StableDiffusionMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
FromSingleFileMixin,
|
||||
):
|
||||
_last_supported_version = "0.33.1"
|
||||
r"""
|
||||
Pipeline for zero-shot text-to-video generation using Stable Diffusion.
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ from ...utils import (
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin
|
||||
|
||||
|
||||
if is_invisible_watermark_available():
|
||||
@@ -346,11 +346,13 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
|
||||
|
||||
class TextToVideoZeroSDXLPipeline(
|
||||
DeprecatedPipelineMixin,
|
||||
DiffusionPipeline,
|
||||
StableDiffusionMixin,
|
||||
StableDiffusionXLLoraLoaderMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
):
|
||||
_last_supported_version = "0.33.1"
|
||||
r"""
|
||||
Pipeline for zero-shot text-to-video generation using Stable Diffusion XL.
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel
|
||||
from ...schedulers import UnCLIPScheduler
|
||||
from ...utils import is_torch_xla_available, logging
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput
|
||||
from .text_proj import UnCLIPTextProjModel
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ else:
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class UnCLIPPipeline(DiffusionPipeline):
|
||||
class UnCLIPPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for text-to-image generation using unCLIP.
|
||||
|
||||
@@ -69,6 +69,7 @@ class UnCLIPPipeline(DiffusionPipeline):
|
||||
|
||||
"""
|
||||
|
||||
_last_supported_version = "0.33.1"
|
||||
_exclude_from_cpu_offload = ["prior"]
|
||||
|
||||
prior: PriorTransformer
|
||||
|
||||
@@ -29,7 +29,7 @@ from ...models import UNet2DConditionModel, UNet2DModel
|
||||
from ...schedulers import UnCLIPScheduler
|
||||
from ...utils import is_torch_xla_available, logging
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput
|
||||
from .text_proj import UnCLIPTextProjModel
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ else:
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class UnCLIPImageVariationPipeline(DiffusionPipeline):
|
||||
class UnCLIPImageVariationPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
|
||||
"""
|
||||
Pipeline to generate image variations from an input image using UnCLIP.
|
||||
|
||||
@@ -73,6 +73,7 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
|
||||
Scheduler used in the super resolution denoising process (a modified [`DDPMScheduler`]).
|
||||
"""
|
||||
|
||||
_last_supported_version = "0.33.1"
|
||||
decoder: UNet2DConditionModel
|
||||
text_proj: UnCLIPTextProjModel
|
||||
text_encoder: CLIPTextModelWithProjection
|
||||
|
||||
@@ -28,7 +28,7 @@ from ...utils import (
|
||||
)
|
||||
from ...utils.outputs import BaseOutput
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline
|
||||
from .modeling_text_decoder import UniDiffuserTextDecoder
|
||||
from .modeling_uvit import UniDiffuserModel
|
||||
|
||||
@@ -62,7 +62,7 @@ class ImageTextPipelineOutput(BaseOutput):
|
||||
text: Optional[Union[List[str], List[List[str]]]]
|
||||
|
||||
|
||||
class UniDiffuserPipeline(DiffusionPipeline):
|
||||
class UniDiffuserPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for a bimodal image-text model which supports unconditional text and image generation, text-conditioned
|
||||
image generation, image-conditioned text generation, and joint image-text generation.
|
||||
@@ -96,6 +96,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
|
||||
original UniDiffuser paper uses the [`DPMSolverMultistepScheduler`] scheduler.
|
||||
"""
|
||||
|
||||
_last_supported_version = "0.33.1"
|
||||
# TODO: support for moving submodules for components with enable_model_cpu_offload
|
||||
model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae->text_decoder"
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
_import_structure["pipeline_wan"] = ["WanPipeline"]
|
||||
_import_structure["pipeline_wan_i2v"] = ["WanImageToVideoPipeline"]
|
||||
_import_structure["pipeline_wan_vace"] = ["WanVACEPipeline"]
|
||||
_import_structure["pipeline_wan_video2video"] = ["WanVideoToVideoPipeline"]
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
@@ -35,6 +36,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
else:
|
||||
from .pipeline_wan import WanPipeline
|
||||
from .pipeline_wan_i2v import WanImageToVideoPipeline
|
||||
from .pipeline_wan_vace import WanVACEPipeline
|
||||
from .pipeline_wan_video2video import WanVideoToVideoPipeline
|
||||
|
||||
else:
|
||||
|
||||
950
src/diffusers/pipelines/wan/pipeline_wan_vace.py
Normal file
950
src/diffusers/pipelines/wan/pipeline_wan_vace.py
Normal file
@@ -0,0 +1,950 @@
|
||||
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import html
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import PIL.Image
|
||||
import regex as re
|
||||
import torch
|
||||
from transformers import AutoTokenizer, UMT5EncoderModel
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...image_processor import PipelineImageInput
|
||||
from ...loaders import WanLoraLoaderMixin
|
||||
from ...models import AutoencoderKLWan, WanVACETransformer3DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import WanPipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
if is_ftfy_available():
|
||||
import ftfy
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> import PIL.Image
|
||||
>>> from diffusers import AutoencoderKLWan, WanVACEPipeline
|
||||
>>> from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
|
||||
>>> from diffusers.utils import export_to_video, load_image
|
||||
def prepare_video_and_mask(first_img: PIL.Image.Image, last_img: PIL.Image.Image, height: int, width: int, num_frames: int):
|
||||
first_img = first_img.resize((width, height))
|
||||
last_img = last_img.resize((width, height))
|
||||
frames = []
|
||||
frames.append(first_img)
|
||||
# Ideally, this should be 127.5 to match original code, but they perform computation on numpy arrays
|
||||
# whereas we are passing PIL images. If you choose to pass numpy arrays, you can set it to 127.5 to
|
||||
# match the original code.
|
||||
frames.extend([PIL.Image.new("RGB", (width, height), (128, 128, 128))] * (num_frames - 2))
|
||||
frames.append(last_img)
|
||||
mask_black = PIL.Image.new("L", (width, height), 0)
|
||||
mask_white = PIL.Image.new("L", (width, height), 255)
|
||||
mask = [mask_black, *[mask_white] * (num_frames - 2), mask_black]
|
||||
return frames, mask
|
||||
|
||||
>>> # Available checkpoints: Wan-AI/Wan2.1-VACE-1.3B-diffusers, Wan-AI/Wan2.1-VACE-14B-diffusers
|
||||
>>> model_id = "Wan-AI/Wan2.1-VACE-1.3B-diffusers"
|
||||
>>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
>>> pipe = WanVACEPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
|
||||
>>> flow_shift = 3.0 # 5.0 for 720P, 3.0 for 480P
|
||||
>>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
|
||||
>>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
|
||||
>>> first_frame = load_image(
|
||||
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png"
|
||||
... )
|
||||
>>> last_frame = load_image(
|
||||
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png>>> "
|
||||
... )
|
||||
|
||||
>>> height = 512
|
||||
>>> width = 512
|
||||
>>> num_frames = 81
|
||||
>>> video, mask = prepare_video_and_mask(first_frame, last_frame, height, width, num_frames)
|
||||
|
||||
>>> output = pipe(
|
||||
... video=video,
|
||||
... mask=mask,
|
||||
... prompt=prompt,
|
||||
... negative_prompt=negative_prompt,
|
||||
... height=height,
|
||||
... width=width,
|
||||
... num_frames=num_frames,
|
||||
... num_inference_steps=30,
|
||||
... guidance_scale=5.0,
|
||||
... generator=torch.Generator().manual_seed(42),
|
||||
... ).frames[0]
|
||||
>>> export_to_video(output, "output.mp4", fps=16)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def basic_clean(text):
|
||||
text = ftfy.fix_text(text)
|
||||
text = html.unescape(html.unescape(text))
|
||||
return text.strip()
|
||||
|
||||
|
||||
def whitespace_clean(text):
|
||||
text = re.sub(r"\s+", " ", text)
|
||||
text = text.strip()
|
||||
return text
|
||||
|
||||
|
||||
def prompt_clean(text):
|
||||
text = whitespace_clean(basic_clean(text))
|
||||
return text
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for controllable generation using Wan.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
Args:
|
||||
tokenizer ([`T5Tokenizer`]):
|
||||
Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
|
||||
specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
||||
the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
|
||||
transformer ([`WanTransformer3DModel`]):
|
||||
Conditional Transformer to denoise the input latents.
|
||||
scheduler ([`UniPCMultistepScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLWan`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: AutoTokenizer,
|
||||
text_encoder: UMT5EncoderModel,
|
||||
transformer: WanVACETransformer3DModel,
|
||||
vae: AutoencoderKLWan,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
|
||||
self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
|
||||
# Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_videos_per_prompt: int = 1,
|
||||
max_sequence_length: int = 226,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
prompt = [prompt_clean(u) for u in prompt]
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_attention_mask=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
|
||||
seq_lens = mask.gt(0).sum(dim=1).long()
|
||||
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
|
||||
prompt_embeds = torch.stack(
|
||||
[torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
|
||||
)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_videos_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 226,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use classifier free guidance or not.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device
|
||||
dtype: (`torch.dtype`, *optional*):
|
||||
torch dtype
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=negative_prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
video=None,
|
||||
mask=None,
|
||||
reference_images=None,
|
||||
):
|
||||
base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
|
||||
if height % base != 0 or width % base != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by {base} but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
elif negative_prompt is not None and (
|
||||
not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
|
||||
):
|
||||
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
|
||||
|
||||
if video is not None:
|
||||
if mask is not None:
|
||||
if len(video) != len(mask):
|
||||
raise ValueError(
|
||||
f"Length of `video` {len(video)} and `mask` {len(mask)} do not match. Please make sure that"
|
||||
" they have the same length."
|
||||
)
|
||||
if reference_images is not None:
|
||||
is_pil_image = isinstance(reference_images, PIL.Image.Image)
|
||||
is_list_of_pil_images = isinstance(reference_images, list) and all(
|
||||
isinstance(ref_img, PIL.Image.Image) for ref_img in reference_images
|
||||
)
|
||||
is_list_of_list_of_pil_images = isinstance(reference_images, list) and all(
|
||||
isinstance(ref_img, list) and all(isinstance(ref_img_, PIL.Image.Image) for ref_img_ in ref_img)
|
||||
for ref_img in reference_images
|
||||
)
|
||||
if not (is_pil_image or is_list_of_pil_images or is_list_of_list_of_pil_images):
|
||||
raise ValueError(
|
||||
"`reference_images` has to be of type `PIL.Image.Image` or `list` of `PIL.Image.Image`, or "
|
||||
"`list` of `list` of `PIL.Image.Image`, but is {type(reference_images)}"
|
||||
)
|
||||
if is_list_of_list_of_pil_images and len(reference_images) != 1:
|
||||
raise ValueError(
|
||||
"The pipeline only supports generating one video at a time at the moment. When passing a list "
|
||||
"of list of reference images, where the outer list corresponds to the batch size and the inner "
|
||||
"list corresponds to list of conditioning images per video, please make sure to only pass "
|
||||
"one inner list of reference images (i.e., `[[<image1>, <image2>, ...]]`"
|
||||
)
|
||||
elif mask is not None:
|
||||
raise ValueError("`mask` can only be passed if `video` is passed as well.")
|
||||
|
||||
def preprocess_conditions(
|
||||
self,
|
||||
video: Optional[List[PipelineImageInput]] = None,
|
||||
mask: Optional[List[PipelineImageInput]] = None,
|
||||
reference_images: Optional[Union[PIL.Image.Image, List[PIL.Image.Image], List[List[PIL.Image.Image]]]] = None,
|
||||
batch_size: int = 1,
|
||||
height: int = 480,
|
||||
width: int = 832,
|
||||
num_frames: int = 81,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
):
|
||||
if video is not None:
|
||||
base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
|
||||
video_height, video_width = self.video_processor.get_default_height_width(video[0])
|
||||
|
||||
if video_height * video_width > height * width:
|
||||
scale = min(width / video_width, height / video_height)
|
||||
video_height, video_width = int(video_height * scale), int(video_width * scale)
|
||||
|
||||
if video_height % base != 0 or video_width % base != 0:
|
||||
logger.warning(
|
||||
f"Video height and width should be divisible by {base}, but got {video_height} and {video_width}. "
|
||||
)
|
||||
video_height = (video_height // base) * base
|
||||
video_width = (video_width // base) * base
|
||||
|
||||
assert video_height * video_width <= height * width
|
||||
|
||||
video = self.video_processor.preprocess_video(video, video_height, video_width)
|
||||
image_size = (video_height, video_width) # Use the height/width of video (with possible rescaling)
|
||||
else:
|
||||
video = torch.zeros(batch_size, 3, num_frames, height, width, dtype=dtype, device=device)
|
||||
image_size = (height, width) # Use the height/width provider by user
|
||||
|
||||
if mask is not None:
|
||||
mask = self.video_processor.preprocess_video(mask, image_size[0], image_size[1])
|
||||
mask = torch.clamp((mask + 1) / 2, min=0, max=1)
|
||||
else:
|
||||
mask = torch.ones_like(video)
|
||||
|
||||
video = video.to(dtype=dtype, device=device)
|
||||
mask = mask.to(dtype=dtype, device=device)
|
||||
|
||||
# Make a list of list of images where the outer list corresponds to video batch size and the inner list
|
||||
# corresponds to list of conditioning images per video
|
||||
if reference_images is None or isinstance(reference_images, PIL.Image.Image):
|
||||
reference_images = [[reference_images] for _ in range(video.shape[0])]
|
||||
elif isinstance(reference_images, (list, tuple)) and isinstance(next(iter(reference_images)), PIL.Image.Image):
|
||||
reference_images = [reference_images]
|
||||
elif (
|
||||
isinstance(reference_images, (list, tuple))
|
||||
and isinstance(next(iter(reference_images)), list)
|
||||
and isinstance(next(iter(reference_images[0])), PIL.Image.Image)
|
||||
):
|
||||
reference_images = reference_images
|
||||
else:
|
||||
raise ValueError(
|
||||
"`reference_images` has to be of type `PIL.Image.Image` or `list` of `PIL.Image.Image`, or "
|
||||
"`list` of `list` of `PIL.Image.Image`, but is {type(reference_images)}"
|
||||
)
|
||||
|
||||
if video.shape[0] != len(reference_images):
|
||||
raise ValueError(
|
||||
f"Batch size of `video` {video.shape[0]} and length of `reference_images` {len(reference_images)} does not match."
|
||||
)
|
||||
|
||||
ref_images_lengths = [len(reference_images_batch) for reference_images_batch in reference_images]
|
||||
if any(l != ref_images_lengths[0] for l in ref_images_lengths):
|
||||
raise ValueError(
|
||||
f"All batches of `reference_images` should have the same length, but got {ref_images_lengths}. Support for this "
|
||||
"may be added in the future."
|
||||
)
|
||||
|
||||
reference_images_preprocessed = []
|
||||
for i, reference_images_batch in enumerate(reference_images):
|
||||
preprocessed_images = []
|
||||
for j, image in enumerate(reference_images_batch):
|
||||
if image is None:
|
||||
continue
|
||||
image = self.video_processor.preprocess(image, None, None)
|
||||
img_height, img_width = image.shape[-2:]
|
||||
scale = min(image_size[0] / img_height, image_size[1] / img_width)
|
||||
new_height, new_width = int(img_height * scale), int(img_width * scale)
|
||||
resized_image = torch.nn.functional.interpolate(
|
||||
image, size=(new_height, new_width), mode="bilinear", align_corners=False
|
||||
).squeeze(0) # [C, H, W]
|
||||
top = (image_size[0] - new_height) // 2
|
||||
left = (image_size[1] - new_width) // 2
|
||||
canvas = torch.ones(3, *image_size, device=device, dtype=dtype)
|
||||
canvas[:, top : top + new_height, left : left + new_width] = resized_image
|
||||
preprocessed_images.append(canvas)
|
||||
reference_images_preprocessed.append(preprocessed_images)
|
||||
|
||||
return video, mask, reference_images_preprocessed
|
||||
|
||||
def prepare_video_latents(
|
||||
self,
|
||||
video: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
reference_images: Optional[List[List[torch.Tensor]]] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> torch.Tensor:
|
||||
device = device or self._execution_device
|
||||
|
||||
if isinstance(generator, list):
|
||||
# TODO: support this
|
||||
raise ValueError("Passing a list of generators is not yet supported. This may be supported in the future.")
|
||||
|
||||
if reference_images is None:
|
||||
# For each batch of video, we set no re
|
||||
# ference image (as one or more can be passed by user)
|
||||
reference_images = [[None] for _ in range(video.shape[0])]
|
||||
else:
|
||||
if video.shape[0] != len(reference_images):
|
||||
raise ValueError(
|
||||
f"Batch size of `video` {video.shape[0]} and length of `reference_images` {len(reference_images)} does not match."
|
||||
)
|
||||
|
||||
if video.shape[0] != 1:
|
||||
# TODO: support this
|
||||
raise ValueError(
|
||||
"Generating with more than one video is not yet supported. This may be supported in the future."
|
||||
)
|
||||
|
||||
vae_dtype = self.vae.dtype
|
||||
video = video.to(dtype=vae_dtype)
|
||||
|
||||
latents_mean = torch.tensor(self.vae.config.latents_mean, device=device, dtype=torch.float32).view(
|
||||
1, self.vae.config.z_dim, 1, 1, 1
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std, device=device, dtype=torch.float32).view(
|
||||
1, self.vae.config.z_dim, 1, 1, 1
|
||||
)
|
||||
|
||||
if mask is None:
|
||||
latents = retrieve_latents(self.vae.encode(video), generator, sample_mode="argmax").unbind(0)
|
||||
latents = ((latents.float() - latents_mean) * latents_std).to(vae_dtype)
|
||||
else:
|
||||
mask = mask.to(dtype=vae_dtype)
|
||||
mask = torch.where(mask > 0.5, 1.0, 0.0)
|
||||
inactive = video * (1 - mask)
|
||||
reactive = video * mask
|
||||
inactive = retrieve_latents(self.vae.encode(inactive), generator, sample_mode="argmax")
|
||||
reactive = retrieve_latents(self.vae.encode(reactive), generator, sample_mode="argmax")
|
||||
inactive = ((inactive.float() - latents_mean) * latents_std).to(vae_dtype)
|
||||
reactive = ((reactive.float() - latents_mean) * latents_std).to(vae_dtype)
|
||||
latents = torch.cat([inactive, reactive], dim=1)
|
||||
|
||||
latent_list = []
|
||||
for latent, reference_images_batch in zip(latents, reference_images):
|
||||
for reference_image in reference_images_batch:
|
||||
assert reference_image.ndim == 3
|
||||
reference_image = reference_image.to(dtype=vae_dtype)
|
||||
reference_image = reference_image[None, :, None, :, :] # [1, C, 1, H, W]
|
||||
reference_latent = retrieve_latents(self.vae.encode(reference_image), generator, sample_mode="argmax")
|
||||
reference_latent = ((reference_latent.float() - latents_mean) * latents_std).to(vae_dtype)
|
||||
reference_latent = reference_latent.squeeze(0) # [C, 1, H, W]
|
||||
reference_latent = torch.cat([reference_latent, torch.zeros_like(reference_latent)], dim=0)
|
||||
latent = torch.cat([reference_latent.squeeze(0), latent], dim=1)
|
||||
latent_list.append(latent)
|
||||
return torch.stack(latent_list)
|
||||
|
||||
def prepare_masks(
|
||||
self,
|
||||
mask: torch.Tensor,
|
||||
reference_images: Optional[List[torch.Tensor]] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
) -> torch.Tensor:
|
||||
if isinstance(generator, list):
|
||||
# TODO: support this
|
||||
raise ValueError("Passing a list of generators is not yet supported. This may be supported in the future.")
|
||||
|
||||
if reference_images is None:
|
||||
# For each batch of video, we set no reference image (as one or more can be passed by user)
|
||||
reference_images = [[None] for _ in range(mask.shape[0])]
|
||||
else:
|
||||
if mask.shape[0] != len(reference_images):
|
||||
raise ValueError(
|
||||
f"Batch size of `mask` {mask.shape[0]} and length of `reference_images` {len(reference_images)} does not match."
|
||||
)
|
||||
|
||||
if mask.shape[0] != 1:
|
||||
# TODO: support this
|
||||
raise ValueError(
|
||||
"Generating with more than one video is not yet supported. This may be supported in the future."
|
||||
)
|
||||
|
||||
transformer_patch_size = self.transformer.config.patch_size[1]
|
||||
|
||||
mask_list = []
|
||||
for mask_, reference_images_batch in zip(mask, reference_images):
|
||||
num_channels, num_frames, height, width = mask_.shape
|
||||
new_num_frames = (num_frames + self.vae_scale_factor_temporal - 1) // self.vae_scale_factor_temporal
|
||||
new_height = height // (self.vae_scale_factor_spatial * transformer_patch_size) * transformer_patch_size
|
||||
new_width = width // (self.vae_scale_factor_spatial * transformer_patch_size) * transformer_patch_size
|
||||
mask_ = mask_[0, :, :, :]
|
||||
mask_ = mask_.view(
|
||||
num_frames, new_height, self.vae_scale_factor_spatial, new_width, self.vae_scale_factor_spatial
|
||||
)
|
||||
mask_ = mask_.permute(2, 4, 0, 1, 3).flatten(0, 1) # [8x8, num_frames, new_height, new_width]
|
||||
mask_ = torch.nn.functional.interpolate(
|
||||
mask_.unsqueeze(0), size=(new_num_frames, new_height, new_width), mode="nearest-exact"
|
||||
).squeeze(0)
|
||||
num_ref_images = len(reference_images_batch)
|
||||
if num_ref_images > 0:
|
||||
mask_padding = torch.zeros_like(mask_[:, :num_ref_images, :, :])
|
||||
mask_ = torch.cat([mask_, mask_padding], dim=1)
|
||||
mask_list.append(mask_)
|
||||
return torch.stack(mask_list)
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size: int,
|
||||
num_channels_latents: int = 16,
|
||||
height: int = 480,
|
||||
width: int = 832,
|
||||
num_frames: int = 81,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if latents is not None:
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
|
||||
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
num_latent_frames,
|
||||
int(height) // self.vae_scale_factor_spatial,
|
||||
int(width) // self.vae_scale_factor_spatial,
|
||||
)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
return latents
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1.0
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: Union[str, List[str]] = None,
|
||||
video: Optional[List[PipelineImageInput]] = None,
|
||||
mask: Optional[List[PipelineImageInput]] = None,
|
||||
reference_images: Optional[List[PipelineImageInput]] = None,
|
||||
conditioning_scale: Union[float, List[float], torch.Tensor] = 1.0,
|
||||
height: int = 480,
|
||||
width: int = 832,
|
||||
num_frames: int = 81,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 5.0,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "np",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
height (`int`, defaults to `480`):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, defaults to `832`):
|
||||
The width in pixels of the generated image.
|
||||
num_frames (`int`, defaults to `81`):
|
||||
The number of frames in the generated video.
|
||||
num_inference_steps (`int`, defaults to `50`):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, defaults to `5.0`):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor is generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
||||
provided, text embeddings are generated from the `prompt` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"np"`):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
||||
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
||||
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
||||
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
||||
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
|
||||
The dtype to use for the torch.amp.autocast.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~WanPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where
|
||||
the first element is a list with the generated images and the second element is a list of `bool`s
|
||||
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
|
||||
"""
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# Simplification of implementation for now
|
||||
if not isinstance(prompt, str):
|
||||
raise ValueError("Passing a list of prompts is not yet supported. This may be supported in the future.")
|
||||
if num_videos_per_prompt != 1:
|
||||
raise ValueError(
|
||||
"Generating multiple videos per prompt is not yet supported. This may be supported in the future."
|
||||
)
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
video,
|
||||
mask,
|
||||
reference_images,
|
||||
)
|
||||
|
||||
if num_frames % self.vae_scale_factor_temporal != 1:
|
||||
logger.warning(
|
||||
f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
|
||||
)
|
||||
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
|
||||
num_frames = max(num_frames, 1)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
vae_dtype = self.vae.dtype
|
||||
transformer_dtype = self.transformer.dtype
|
||||
|
||||
if isinstance(conditioning_scale, (int, float)):
|
||||
conditioning_scale = [conditioning_scale] * len(self.transformer.config.vace_layers)
|
||||
if isinstance(conditioning_scale, list):
|
||||
if len(conditioning_scale) != len(self.transformer.config.vace_layers):
|
||||
raise ValueError(
|
||||
f"Length of `conditioning_scale` {len(conditioning_scale)} does not match number of layers {len(self.transformer.config.vace_layers)}."
|
||||
)
|
||||
conditioning_scale = torch.tensor(conditioning_scale)
|
||||
if isinstance(conditioning_scale, torch.Tensor):
|
||||
if conditioning_scale.size(0) != len(self.transformer.config.vace_layers):
|
||||
raise ValueError(
|
||||
f"Length of `conditioning_scale` {conditioning_scale.size(0)} does not match number of layers {len(self.transformer.config.vace_layers)}."
|
||||
)
|
||||
conditioning_scale = conditioning_scale.to(device=device, dtype=transformer_dtype)
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(transformer_dtype)
|
||||
if negative_prompt_embeds is not None:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Prepare latent variables
|
||||
video, mask, reference_images = self.preprocess_conditions(
|
||||
video,
|
||||
mask,
|
||||
reference_images,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
num_frames,
|
||||
torch.float32,
|
||||
device,
|
||||
)
|
||||
num_reference_images = len(reference_images[0])
|
||||
|
||||
conditioning_latents = self.prepare_video_latents(video, mask, reference_images, generator, device)
|
||||
mask = self.prepare_masks(mask, reference_images, generator)
|
||||
conditioning_latents = torch.cat([conditioning_latents, mask], dim=1)
|
||||
conditioning_latents = conditioning_latents.to(transformer_dtype)
|
||||
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_videos_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
num_frames + num_reference_images * self.vae_scale_factor_temporal,
|
||||
torch.float32,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
if conditioning_latents.shape[2] != latents.shape[2]:
|
||||
logger.warning(
|
||||
"The number of frames in the conditioning latents does not match the number of frames to be generated. Generation quality may be affected."
|
||||
)
|
||||
|
||||
# 6. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
latent_model_input = latents.to(transformer_dtype)
|
||||
timestep = t.expand(latents.shape[0])
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
control_hidden_states=conditioning_latents,
|
||||
control_hidden_states_scale=conditioning_scale,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_uncond = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
control_hidden_states=conditioning_latents,
|
||||
control_hidden_states_scale=conditioning_scale,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
latents = latents[:, :, num_reference_images:]
|
||||
latents = latents.to(vae_dtype)
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
||||
latents.device, latents.dtype
|
||||
)
|
||||
latents = latents / latents_std + latents_mean
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
else:
|
||||
video = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return WanPipelineOutput(frames=video)
|
||||
@@ -21,7 +21,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
|
||||
from ...schedulers import DDPMWuerstchenScheduler
|
||||
from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput
|
||||
from .modeling_paella_vq_model import PaellaVQModel
|
||||
from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt
|
||||
|
||||
@@ -56,7 +56,7 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
class WuerstchenDecoderPipeline(DiffusionPipeline):
|
||||
class WuerstchenDecoderPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for generating images from the Wuerstchen model.
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...schedulers import DDPMWuerstchenScheduler
|
||||
from ...utils import deprecate, replace_example_docstring
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline
|
||||
from .modeling_paella_vq_model import PaellaVQModel
|
||||
from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt
|
||||
from .modeling_wuerstchen_prior import WuerstchenPrior
|
||||
@@ -40,7 +40,7 @@ TEXT2IMAGE_EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
class WuerstchenCombinedPipeline(DiffusionPipeline):
|
||||
class WuerstchenCombinedPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
|
||||
"""
|
||||
Combined Pipeline for text-to-image generation using Wuerstchen
|
||||
|
||||
@@ -68,6 +68,7 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
|
||||
The scheduler to be used for prior pipeline.
|
||||
"""
|
||||
|
||||
_last_supported_version = "0.33.1"
|
||||
_load_connected_pipes = True
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -1150,6 +1150,21 @@ class WanTransformer3DModel(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class WanVACETransformer3DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
def get_constant_schedule(*args, **kwargs):
|
||||
requires_backends(get_constant_schedule, ["torch"])
|
||||
|
||||
|
||||
@@ -2897,6 +2897,21 @@ class WanPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class WanVACEPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class WanVideoToVideoPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -1744,6 +1744,10 @@ class ModelPushToHubTester(unittest.TestCase):
|
||||
delete_repo(self.repo_id, token=TOKEN)
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_2
|
||||
@is_torch_compile
|
||||
@slow
|
||||
class TorchCompileTesterMixin:
|
||||
def setUp(self):
|
||||
# clean up the VRAM before each test
|
||||
@@ -1759,12 +1763,7 @@ class TorchCompileTesterMixin:
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_2
|
||||
@is_torch_compile
|
||||
@slow
|
||||
def test_torch_compile_recompilation_and_graph_break(self):
|
||||
torch.compiler.reset()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
@@ -1778,6 +1777,31 @@ class TorchCompileTesterMixin:
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
def test_compile_with_group_offloading(self):
|
||||
torch._dynamo.config.cache_size_limit = 10000
|
||||
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
if not getattr(model, "_supports_group_offloading", True):
|
||||
return
|
||||
|
||||
model.eval()
|
||||
# TODO: Can test for other group offloading kwargs later if needed.
|
||||
group_offload_kwargs = {
|
||||
"onload_device": "cuda",
|
||||
"offload_device": "cpu",
|
||||
"offload_type": "block_level",
|
||||
"num_blocks_per_group": 1,
|
||||
"use_stream": True,
|
||||
"non_blocking": True,
|
||||
}
|
||||
model.enable_group_offload(**group_offload_kwargs)
|
||||
model.compile()
|
||||
with torch.no_grad():
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_2
|
||||
|
||||
@@ -78,9 +78,7 @@ def create_flux_ip_adapter_state_dict(model):
|
||||
return ip_state_dict
|
||||
|
||||
|
||||
class FluxTransformerTests(
|
||||
ModelTesterMixin, TorchCompileTesterMixin, LoraHotSwappingForModelTesterMixin, unittest.TestCase
|
||||
):
|
||||
class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = FluxTransformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
# We override the items here because the transformer under consideration is small.
|
||||
@@ -169,3 +167,17 @@ class FluxTransformerTests(
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"FluxTransformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = FluxTransformer2DModel
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return FluxTransformerTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
|
||||
class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
|
||||
model_class = FluxTransformer2DModel
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return FluxTransformerTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
@@ -28,7 +28,7 @@ from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class HunyuanVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
|
||||
class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = HunyuanVideoTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
@@ -93,7 +93,14 @@ class HunyuanVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin,
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
|
||||
class HunyuanTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = HunyuanVideoTransformer3DModel
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return HunyuanVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
|
||||
class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = HunyuanVideoTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
@@ -161,7 +168,14 @@ class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompi
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
|
||||
class HunyuanSkyreelsImageToVideoCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = HunyuanVideoTransformer3DModel
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return HunyuanSkyreelsImageToVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
|
||||
class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = HunyuanVideoTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
@@ -227,9 +241,14 @@ class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, TorchCompileT
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(
|
||||
ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase
|
||||
):
|
||||
class HunyuanImageToVideoCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = HunyuanVideoTransformer3DModel
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return HunyuanVideoImageToVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
|
||||
class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = HunyuanVideoTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
@@ -295,3 +314,10 @@ class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"HunyuanVideoTransformer3DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class HunyuanVideoTokenReplaceCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = HunyuanVideoTransformer3DModel
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return HunyuanVideoTokenReplaceImageToVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
@@ -26,7 +26,7 @@ from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class LTXTransformerTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
|
||||
class LTXTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = LTXVideoTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
@@ -81,3 +81,10 @@ class LTXTransformerTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.Te
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"LTXVideoTransformer3DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class LTXTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = LTXVideoTransformer3DModel
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return LTXTransformerTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
@@ -28,7 +28,7 @@ from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class WanTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
|
||||
class WanTransformer3DTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = WanTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
@@ -82,3 +82,10 @@ class WanTransformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"WanTransformer3DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class WanTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = WanTransformer3DModel
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return WanTransformer3DTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
@@ -355,9 +355,7 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True):
|
||||
return custom_diffusion_attn_procs
|
||||
|
||||
|
||||
class UNet2DConditionModelTests(
|
||||
ModelTesterMixin, TorchCompileTesterMixin, LoraHotSwappingForModelTesterMixin, UNetTesterMixin, unittest.TestCase
|
||||
):
|
||||
class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = UNet2DConditionModel
|
||||
main_input_name = "sample"
|
||||
# We override the items here because the unet under consideration is small.
|
||||
@@ -1147,6 +1145,20 @@ class UNet2DConditionModelTests(
|
||||
assert "Using the `save_attn_procs()` method has been deprecated" in warning_message
|
||||
|
||||
|
||||
class UNet2DConditionModelCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = UNet2DConditionModel
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return UNet2DConditionModelTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
|
||||
class UNet2DConditionModelLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
|
||||
model_class = UNet2DConditionModel
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return UNet2DConditionModelTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
|
||||
@slow
|
||||
class UNet2DConditionModelIntegrationTests(unittest.TestCase):
|
||||
def get_file_format(self, seed, shape):
|
||||
|
||||
@@ -1115,6 +1115,14 @@ class PipelineTesterMixin:
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
# Skip tests for pipelines that inherit from DeprecatedPipelineMixin
|
||||
from diffusers.pipelines.pipeline_utils import DeprecatedPipelineMixin
|
||||
|
||||
if hasattr(self, "pipeline_class") and issubclass(self.pipeline_class, DeprecatedPipelineMixin):
|
||||
import pytest
|
||||
|
||||
pytest.skip(reason=f"Deprecated Pipeline: {self.pipeline_class.__name__}")
|
||||
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test in case of CUDA runtime errors
|
||||
super().tearDown()
|
||||
|
||||
201
tests/pipelines/wan/test_wan_vace.py
Normal file
201
tests/pipelines/wan/test_wan_vace.py
Normal file
@@ -0,0 +1,201 @@
|
||||
# Copyright 2025 The HuggingFace Team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanVACEPipeline, WanVACETransformer3DModel
|
||||
from diffusers.utils.testing_utils import enable_full_determinism
|
||||
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class WanVACEPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = WanVACEPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
required_optional_params = frozenset(
|
||||
[
|
||||
"num_inference_steps",
|
||||
"generator",
|
||||
"latents",
|
||||
"return_dict",
|
||||
"callback_on_step_end",
|
||||
"callback_on_step_end_tensor_inputs",
|
||||
]
|
||||
)
|
||||
test_xformers_attention = False
|
||||
supports_dduf = False
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKLWan(
|
||||
base_dim=3,
|
||||
z_dim=16,
|
||||
dim_mult=[1, 1, 1, 1],
|
||||
num_res_blocks=1,
|
||||
temperal_downsample=[False, True, True],
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
|
||||
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
torch.manual_seed(0)
|
||||
transformer = WanVACETransformer3DModel(
|
||||
patch_size=(1, 2, 2),
|
||||
num_attention_heads=2,
|
||||
attention_head_dim=12,
|
||||
in_channels=16,
|
||||
out_channels=16,
|
||||
text_dim=32,
|
||||
freq_dim=256,
|
||||
ffn_dim=32,
|
||||
num_layers=3,
|
||||
cross_attn_norm=True,
|
||||
qk_norm="rms_norm_across_heads",
|
||||
rope_max_seq_len=32,
|
||||
vace_layers=[0, 2],
|
||||
vace_in_channels=96,
|
||||
)
|
||||
|
||||
components = {
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
num_frames = 17
|
||||
height = 16
|
||||
width = 16
|
||||
|
||||
video = [Image.new("RGB", (height, width))] * num_frames
|
||||
mask = [Image.new("L", (height, width), 0)] * num_frames
|
||||
|
||||
inputs = {
|
||||
"video": video,
|
||||
"mask": mask,
|
||||
"prompt": "dance monkey",
|
||||
"negative_prompt": "negative",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
"height": 16,
|
||||
"width": 16,
|
||||
"num_frames": num_frames,
|
||||
"max_sequence_length": 16,
|
||||
"output_type": "pt",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_inference(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
video = pipe(**inputs).frames[0]
|
||||
self.assertEqual(video.shape, (17, 3, 16, 16))
|
||||
|
||||
# fmt: off
|
||||
expected_slice = [0.4523, 0.45198, 0.44872, 0.45326, 0.45211, 0.45258, 0.45344, 0.453, 0.52431, 0.52572, 0.50701, 0.5118, 0.53717, 0.53093, 0.50557, 0.51402]
|
||||
# fmt: on
|
||||
|
||||
video_slice = video.flatten()
|
||||
video_slice = torch.cat([video_slice[:8], video_slice[-8:]])
|
||||
video_slice = [round(x, 5) for x in video_slice.tolist()]
|
||||
self.assertTrue(np.allclose(video_slice, expected_slice, atol=1e-3))
|
||||
|
||||
def test_inference_with_single_reference_image(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["reference_images"] = Image.new("RGB", (16, 16))
|
||||
video = pipe(**inputs).frames[0]
|
||||
self.assertEqual(video.shape, (17, 3, 16, 16))
|
||||
|
||||
# fmt: off
|
||||
expected_slice = [0.45247, 0.45214, 0.44874, 0.45314, 0.45171, 0.45299, 0.45428, 0.45317, 0.51378, 0.52658, 0.53361, 0.52303, 0.46204, 0.50435, 0.52555, 0.51342]
|
||||
# fmt: on
|
||||
|
||||
video_slice = video.flatten()
|
||||
video_slice = torch.cat([video_slice[:8], video_slice[-8:]])
|
||||
video_slice = [round(x, 5) for x in video_slice.tolist()]
|
||||
self.assertTrue(np.allclose(video_slice, expected_slice, atol=1e-3))
|
||||
|
||||
def test_inference_with_multiple_reference_image(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["reference_images"] = [[Image.new("RGB", (16, 16))] * 2]
|
||||
video = pipe(**inputs).frames[0]
|
||||
self.assertEqual(video.shape, (17, 3, 16, 16))
|
||||
|
||||
# fmt: off
|
||||
expected_slice = [0.45321, 0.45221, 0.44818, 0.45375, 0.45268, 0.4519, 0.45271, 0.45253, 0.51244, 0.52223, 0.51253, 0.51321, 0.50743, 0.51177, 0.51626, 0.50983]
|
||||
# fmt: on
|
||||
|
||||
video_slice = video.flatten()
|
||||
video_slice = torch.cat([video_slice[:8], video_slice[-8:]])
|
||||
video_slice = [round(x, 5) for x in video_slice.tolist()]
|
||||
self.assertTrue(np.allclose(video_slice, expected_slice, atol=1e-3))
|
||||
|
||||
@unittest.skip("Test not supported")
|
||||
def test_attention_slicing_forward_pass(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Errors out because passing multiple prompts at once is not yet supported by this pipeline.")
|
||||
def test_encode_prompt_works_in_isolation(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Batching is not yet supported with this pipeline")
|
||||
def test_inference_batch_consistent(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Batching is not yet supported with this pipeline")
|
||||
def test_inference_batch_single_identical(self):
|
||||
return super().test_inference_batch_single_identical()
|
||||
Reference in New Issue
Block a user