
[modular] wan! (#12611)

* update, remove intermediate_inputs

* support image2video

* revert dynamic steps to simplify

* refactor vae encoder block

* support flf2video!

* add support for wan2.2 14B

* style

* Apply suggestions from code review

* input dynamic step -> additional input step

* up

* fix init

* update dtype
YiYi Xu
2025-11-09 21:48:50 -10:00
committed by GitHub
parent 04f9d2bf3d
commit b455dc94a2
23 changed files with 2020 additions and 387 deletions

View File

@@ -407,6 +407,7 @@ else:
"QwenImageModularPipeline",
"StableDiffusionXLAutoBlocks",
"StableDiffusionXLModularPipeline",
"Wan22AutoBlocks",
"WanAutoBlocks",
"WanModularPipeline",
]
@@ -1090,6 +1091,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
QwenImageModularPipeline,
StableDiffusionXLAutoBlocks,
StableDiffusionXLModularPipeline,
Wan22AutoBlocks,
WanAutoBlocks,
WanModularPipeline,
)

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
@@ -88,6 +88,19 @@ class AdaptiveProjectedGuidance(BaseGuidance):
data_batches.append(data_batch)
return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
if self._step == 0:
if self.adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
@@ -99,6 +99,19 @@ class AdaptiveProjectedMixGuidance(BaseGuidance):
data_batches.append(data_batch)
return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
if self._step == 0:
if self.adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None

View File

@@ -141,6 +141,16 @@ class AutoGuidance(BaseGuidance):
data_batches.append(data_batch)
return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
@@ -99,6 +99,16 @@ class ClassifierFreeGuidance(BaseGuidance):
data_batches.append(data_batch)
return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
@@ -85,6 +85,16 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance):
data_batches.append(data_batch)
return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None

View File

@@ -226,6 +226,16 @@ class FrequencyDecoupledGuidance(BaseGuidance):
data_batches.append(data_batch)
return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None

View File

@@ -166,6 +166,11 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
raise NotImplementedError("BaseGuidance::prepare_inputs_from_block_state must be implemented in subclasses.")
def __call__(self, data: List["BlockState"]) -> Any:
if not all(hasattr(d, "noise_pred") for d in data):
raise ValueError("Expected all data to have `noise_pred` attribute.")
@@ -234,6 +239,51 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
data_batch[cls._identifier_key] = identifier
return BlockState(**data_batch)
@classmethod
def _prepare_batch_from_block_state(
cls,
input_fields: Dict[str, Union[str, Tuple[str, str]]],
data: "BlockState",
tuple_index: int,
identifier: str,
) -> "BlockState":
"""
Prepares a batch of data for the guidance technique. This method is used by the `prepare_inputs_from_block_state`
method of `BaseGuidance` subclasses. It prepares the batch based on the provided tuple index.
Args:
input_fields (`Dict[str, Union[str, Tuple[str, str]]]`):
A dictionary where the keys are the names of the fields that will be used to store the data once it is
prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
to look up the required data provided for preparation. If a string is provided, it will be used as the
conditional data (or unconditional if used with a guidance method that requires it). If a tuple of
length 2 is provided, the first element must be the conditional data identifier and the second element
must be the unconditional data identifier or None.
data (`BlockState`):
The input data to be prepared.
tuple_index (`int`):
The index to use when accessing input fields that are tuples.
Returns:
`BlockState`: The prepared batch of data.
"""
from ..modular_pipelines.modular_pipeline import BlockState
data_batch = {}
for key, value in input_fields.items():
try:
if isinstance(value, str):
data_batch[key] = getattr(data, value)
elif isinstance(value, tuple):
data_batch[key] = getattr(data, value[tuple_index])
else:
# We've already checked that value is a string or a tuple of strings with length 2
pass
except AttributeError:
logger.debug(f"`data` does not have attribute(s) {value}, skipping.")
data_batch[cls._identifier_key] = identifier
return BlockState(**data_batch)
@classmethod
@validate_hf_hub_args
def from_pretrained(
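The `input_fields` mapping documented above is the key convention for the new `prepare_inputs_from_block_state` path: a plain string names a field shared by every guidance batch, while a (cond, uncond) tuple is split across batches by tuple index. Below is a minimal standalone sketch of that resolution for the two-batch CFG case; the `prepare_batches` helper and the tensor shapes are illustrative only and not part of diffusers (the real method also tolerates missing attributes and handles more than two batches).

from types import SimpleNamespace
import torch

def prepare_batches(block_state, input_fields, identifiers):
    # tuple index 0 -> conditional batch, 1 -> unconditional batch (mirrors the docstring above)
    batches = []
    for tuple_idx, identifier in enumerate(identifiers):
        batch = {}
        for key, value in input_fields.items():
            if isinstance(value, str):
                batch[key] = getattr(block_state, value)  # same tensor for every batch
            else:
                batch[key] = getattr(block_state, value[tuple_idx])  # (cond, uncond) pair
        batch["__guidance_identifier__"] = identifier
        batches.append(batch)
    return batches

block_state = SimpleNamespace(
    prompt_embeds=torch.randn(1, 77, 4096),            # illustrative shapes
    negative_prompt_embeds=torch.randn(1, 77, 4096),
    image_embeds=torch.randn(1, 257, 1280),
)
input_fields = {
    "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
    "encoder_hidden_states_image": "image_embeds",
}
batches = prepare_batches(block_state, input_fields, ["pred_cond", "pred_uncond"])
print([b["__guidance_identifier__"] for b in batches])  # ['pred_cond', 'pred_uncond']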

View File

@@ -187,6 +187,26 @@ class PerturbedAttentionGuidance(BaseGuidance):
data_batches.append(data_batch)
return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
elif self.num_conditions == 2:
tuple_indices = [0, 1]
input_predictions = (
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
)
else:
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.forward
def forward(
self,

View File

@@ -183,6 +183,26 @@ class SkipLayerGuidance(BaseGuidance):
data_batches.append(data_batch)
return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
elif self.num_conditions == 2:
tuple_indices = [0, 1]
input_predictions = (
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
)
else:
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def forward(
self,
pred_cond: torch.Tensor,

View File

@@ -172,6 +172,26 @@ class SmoothedEnergyGuidance(BaseGuidance):
data_batches.append(data_batch)
return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
elif self.num_conditions == 2:
tuple_indices = [0, 1]
input_predictions = (
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_seg"]
)
else:
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def forward(
self,
pred_cond: torch.Tensor,

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
@@ -74,6 +74,16 @@ class TangentialClassifierFreeGuidance(BaseGuidance):
data_batches.append(data_batch)
return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None

View File

@@ -45,7 +45,7 @@ else:
"InsertableDict",
]
_import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
_import_structure["wan"] = ["WanAutoBlocks", "WanModularPipeline"]
_import_structure["wan"] = ["WanAutoBlocks", "Wan22AutoBlocks", "WanModularPipeline"]
_import_structure["flux"] = [
"FluxAutoBlocks",
"FluxModularPipeline",
@@ -90,7 +90,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
QwenImageModularPipeline,
)
from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
from .wan import WanAutoBlocks, WanModularPipeline
from .wan import Wan22AutoBlocks, WanAutoBlocks, WanModularPipeline
else:
import sys

View File

@@ -1441,6 +1441,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
components_manager: Optional[ComponentsManager] = None,
collection: Optional[str] = None,
modular_config_dict: Optional[Dict[str, Any]] = None,
config_dict: Optional[Dict[str, Any]] = None,
**kwargs,
):
"""
@@ -1492,23 +1494,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
- The pipeline's config dict is also used to store the pipeline blocks's class name, which will be saved as
`_blocks_class_name` in the config dict
"""
if blocks is None:
blocks_class_name = self.default_blocks_name
if blocks_class_name is not None:
diffusers_module = importlib.import_module("diffusers")
blocks_class = getattr(diffusers_module, blocks_class_name)
blocks = blocks_class()
else:
logger.warning(f"`blocks` is `None`, no default blocks class found for {self.__class__.__name__}")
self.blocks = blocks
self._components_manager = components_manager
self._collection = collection
self._component_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_components}
self._config_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_configs}
# update component_specs and config_specs from modular_repo
if pretrained_model_name_or_path is not None:
if modular_config_dict is None and config_dict is None and pretrained_model_name_or_path is not None:
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
@@ -1524,52 +1511,59 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
"local_files_only": local_files_only,
"revision": revision,
}
# try to load modular_model_index.json
try:
config_dict = self.load_config(pretrained_model_name_or_path, **load_config_kwargs)
except EnvironmentError as e:
logger.debug(f"modular_model_index.json not found: {e}")
config_dict = None
# update component_specs and config_specs based on modular_model_index.json
if config_dict is not None:
for name, value in config_dict.items():
# all the components in modular_model_index.json are from_pretrained components
if name in self._component_specs and isinstance(value, (tuple, list)) and len(value) == 3:
library, class_name, component_spec_dict = value
component_spec = self._dict_to_component_spec(name, component_spec_dict)
component_spec.default_creation_method = "from_pretrained"
self._component_specs[name] = component_spec
modular_config_dict, config_dict = self._load_pipeline_config(
pretrained_model_name_or_path, **load_config_kwargs
)
elif name in self._config_specs:
self._config_specs[name].default = value
# if modular_model_index.json is not found, try to load model_index.json
if blocks is None:
if modular_config_dict is not None:
blocks_class_name = modular_config_dict.get("_blocks_class_name")
elif config_dict is not None:
blocks_class_name = self.get_default_blocks_name(config_dict)
else:
logger.debug(" loading config from model_index.json")
try:
from diffusers import DiffusionPipeline
blocks_class_name = None
if blocks_class_name is not None:
diffusers_module = importlib.import_module("diffusers")
blocks_class = getattr(diffusers_module, blocks_class_name)
blocks = blocks_class()
else:
logger.warning(f"`blocks` is `None`, no default blocks class found for {self.__class__.__name__}")
config_dict = DiffusionPipeline.load_config(pretrained_model_name_or_path, **load_config_kwargs)
except EnvironmentError as e:
logger.debug(f" model_index.json not found in the repo: {e}")
config_dict = None
self.blocks = blocks
self._components_manager = components_manager
self._collection = collection
self._component_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_components}
self._config_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_configs}
# update component_specs and config_specs based on model_index.json
if config_dict is not None:
for name, value in config_dict.items():
if name in self._component_specs and isinstance(value, (tuple, list)) and len(value) == 2:
library, class_name = value
component_spec_dict = {
"repo": pretrained_model_name_or_path,
"subfolder": name,
"type_hint": (library, class_name),
}
component_spec = self._dict_to_component_spec(name, component_spec_dict)
component_spec.default_creation_method = "from_pretrained"
self._component_specs[name] = component_spec
elif name in self._config_specs:
self._config_specs[name].default = value
# update component_specs and config_specs based on modular_model_index.json
if modular_config_dict is not None:
for name, value in modular_config_dict.items():
# all the components in modular_model_index.json are from_pretrained components
if name in self._component_specs and isinstance(value, (tuple, list)) and len(value) == 3:
library, class_name, component_spec_dict = value
component_spec = self._dict_to_component_spec(name, component_spec_dict)
component_spec.default_creation_method = "from_pretrained"
self._component_specs[name] = component_spec
elif name in self._config_specs:
self._config_specs[name].default = value
# if `modular_config_dict` is None (i.e. `modular_model_index.json` is not found), update based on `config_dict` (i.e. `model_index.json`)
elif config_dict is not None:
for name, value in config_dict.items():
if name in self._component_specs and isinstance(value, (tuple, list)) and len(value) == 2:
library, class_name = value
component_spec_dict = {
"repo": pretrained_model_name_or_path,
"subfolder": name,
"type_hint": (library, class_name),
}
component_spec = self._dict_to_component_spec(name, component_spec_dict)
component_spec.default_creation_method = "from_pretrained"
self._component_specs[name] = component_spec
elif name in self._config_specs:
self._config_specs[name].default = value
if len(kwargs) > 0:
logger.warning(f"Unexpected input '{kwargs.keys()}' provided. This input will be ignored.")
@@ -1601,6 +1595,35 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
params[input_param.name] = input_param.default
return params
def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
return self.default_blocks_name
@classmethod
def _load_pipeline_config(
cls,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
**load_config_kwargs,
):
try:
# try to load modular_model_index.json
modular_config_dict = cls.load_config(pretrained_model_name_or_path, **load_config_kwargs)
return modular_config_dict, None
except EnvironmentError as e:
logger.debug(f" modular_model_index.json not found in the repo: {e}")
try:
logger.debug(" try to load model_index.json")
from diffusers import DiffusionPipeline
config_dict = DiffusionPipeline.load_config(pretrained_model_name_or_path, **load_config_kwargs)
return None, config_dict
except EnvironmentError as e:
logger.debug(f" model_index.json not found in the repo: {e}")
return None, None
@classmethod
@validate_hf_hub_args
def from_pretrained(
@@ -1655,42 +1678,33 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
"revision": revision,
}
try:
# try to load modular_model_index.json
config_dict = cls.load_config(pretrained_model_name_or_path, **load_config_kwargs)
except EnvironmentError as e:
logger.debug(f" modular_model_index.json not found in the repo: {e}")
config_dict = None
modular_config_dict, config_dict = cls._load_pipeline_config(
pretrained_model_name_or_path, **load_config_kwargs
)
if config_dict is not None:
pipeline_class = _get_pipeline_class(cls, config=config_dict)
if modular_config_dict is not None:
pipeline_class = _get_pipeline_class(cls, config=modular_config_dict)
elif config_dict is not None:
from diffusers.pipelines.auto_pipeline import _get_model
logger.debug(" try to determine the modular pipeline class from model_index.json")
standard_pipeline_class = _get_pipeline_class(cls, config=config_dict)
model_name = _get_model(standard_pipeline_class.__name__)
pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(model_name, ModularPipeline.__name__)
diffusers_module = importlib.import_module("diffusers")
pipeline_class = getattr(diffusers_module, pipeline_class_name)
else:
try:
logger.debug(" try to load model_index.json")
from diffusers import DiffusionPipeline
from diffusers.pipelines.auto_pipeline import _get_model
config_dict = DiffusionPipeline.load_config(pretrained_model_name_or_path, **load_config_kwargs)
except EnvironmentError as e:
logger.debug(f" model_index.json not found in the repo: {e}")
if config_dict is not None:
logger.debug(" try to determine the modular pipeline class from model_index.json")
standard_pipeline_class = _get_pipeline_class(cls, config=config_dict)
model_name = _get_model(standard_pipeline_class.__name__)
pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(model_name, ModularPipeline.__name__)
diffusers_module = importlib.import_module("diffusers")
pipeline_class = getattr(diffusers_module, pipeline_class_name)
else:
# there is no config for modular pipeline, assuming that the pipeline block does not need any from_pretrained components
pipeline_class = cls
pretrained_model_name_or_path = None
# there is no config for modular pipeline, assuming that the pipeline block does not need any from_pretrained components
pipeline_class = cls
pretrained_model_name_or_path = None
pipeline = pipeline_class(
blocks=blocks,
pretrained_model_name_or_path=pretrained_model_name_or_path,
components_manager=components_manager,
collection=collection,
modular_config_dict=modular_config_dict,
config_dict=config_dict,
**kwargs,
)
return pipeline
@@ -2134,7 +2148,9 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
logger.warning(
f"\nFailed to create component {name}:\n"
f"- Component spec: {spec}\n"
f"- load() called with kwargs: {component_load_kwargs}\n\n"
f"- load() called with kwargs: {component_load_kwargs}\n"
"If this component is not required for your workflow you can safely ignore this message.\n\n"
"Traceback:\n"
f"{traceback.format_exc()}"
)
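The refactor above moves config discovery into `_load_pipeline_config`: `modular_model_index.json` is preferred, `model_index.json` is the fallback, and if neither exists the pipeline assumes its blocks need no from_pretrained components. A rough local-files-only sketch of that precedence follows; it is illustrative only, since the actual helper goes through the `load_config` machinery and also supports Hub repo ids and the usual download kwargs.

import json
import os
from typing import Optional, Tuple

def resolve_pipeline_config(repo_path: str) -> Tuple[Optional[dict], Optional[dict]]:
    """Return (modular_config_dict, config_dict), preferring modular_model_index.json."""
    modular_path = os.path.join(repo_path, "modular_model_index.json")
    standard_path = os.path.join(repo_path, "model_index.json")
    if os.path.isfile(modular_path):
        with open(modular_path) as f:
            return json.load(f), None   # modular config wins
    if os.path.isfile(standard_path):
        with open(standard_path) as f:
            return None, json.load(f)   # fall back to the standard pipeline config
    return None, None                   # no config: blocks need no pretrained components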

View File

@@ -21,16 +21,14 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["decoders"] = ["WanImageVaeDecoderStep"]
_import_structure["encoders"] = ["WanTextEncoderStep"]
_import_structure["modular_blocks"] = [
"ALL_BLOCKS",
"AUTO_BLOCKS",
"TEXT2VIDEO_BLOCKS",
"WanAutoBeforeDenoiseStep",
"Wan22AutoBlocks",
"WanAutoBlocks",
"WanAutoBlocks",
"WanAutoDecodeStep",
"WanAutoDenoiseStep",
"WanAutoImageEncoderStep",
"WanAutoVaeImageEncoderStep",
]
_import_structure["modular_pipeline"] = ["WanModularPipeline"]
@@ -41,15 +39,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .decoders import WanImageVaeDecoderStep
from .encoders import WanTextEncoderStep
from .modular_blocks import (
ALL_BLOCKS,
AUTO_BLOCKS,
TEXT2VIDEO_BLOCKS,
WanAutoBeforeDenoiseStep,
Wan22AutoBlocks,
WanAutoBlocks,
WanAutoDecodeStep,
WanAutoDenoiseStep,
WanAutoImageEncoderStep,
WanAutoVaeImageEncoderStep,
)
from .modular_pipeline import WanModularPipeline
else:

View File

@@ -13,10 +13,11 @@
# limitations under the License.
import inspect
from typing import List, Optional, Union
from typing import List, Optional, Tuple, Union
import torch
from ...models import WanTransformer3DModel
from ...schedulers import UniPCMultistepScheduler
from ...utils import logging
from ...utils.torch_utils import randn_tensor
@@ -34,6 +35,97 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# configuration of guider is.
def repeat_tensor_to_batch_size(
input_name: str,
input_tensor: torch.Tensor,
batch_size: int,
num_videos_per_prompt: int = 1,
) -> torch.Tensor:
"""Repeat tensor elements to match the final batch size.
This function expands a tensor's batch dimension to match the final batch size (batch_size * num_videos_per_prompt)
by repeating each element along dimension 0.
The input tensor must have batch size 1 or batch_size. The function will:
- If batch size is 1: repeat each element (batch_size * num_videos_per_prompt) times
- If batch size equals batch_size: repeat each element num_videos_per_prompt times
Args:
input_name (str): Name of the input tensor (used for error messages)
input_tensor (torch.Tensor): The tensor to repeat. Must have batch size 1 or batch_size.
batch_size (int): The base batch size (number of prompts)
num_videos_per_prompt (int, optional): Number of videos to generate per prompt. Defaults to 1.
Returns:
torch.Tensor: The repeated tensor with final batch size (batch_size * num_videos_per_prompt)
Raises:
ValueError: If input_tensor is not a torch.Tensor or has invalid batch size
Examples:
tensor = torch.tensor([[1, 2, 3]])  # shape: [1, 3]
repeated = repeat_tensor_to_batch_size("image", tensor, batch_size=2, num_videos_per_prompt=2)
repeated  # tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]) - shape: [4, 3]
tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])  # shape: [2, 3]
repeated = repeat_tensor_to_batch_size("image", tensor, batch_size=2, num_videos_per_prompt=2)
repeated  # tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]]) - shape: [4, 3]
"""
# make sure input is a tensor
if not isinstance(input_tensor, torch.Tensor):
raise ValueError(f"`{input_name}` must be a tensor")
# make sure input tensor e.g. image_latents has batch size 1 or batch_size same as prompts
if input_tensor.shape[0] == 1:
repeat_by = batch_size * num_videos_per_prompt
elif input_tensor.shape[0] == batch_size:
repeat_by = num_videos_per_prompt
else:
raise ValueError(
f"`{input_name}` must have have batch size 1 or {batch_size}, but got {input_tensor.shape[0]}"
)
# expand the tensor to match the batch_size * num_videos_per_prompt
input_tensor = input_tensor.repeat_interleave(repeat_by, dim=0)
return input_tensor
def calculate_dimension_from_latents(
latents: torch.Tensor, vae_scale_factor_temporal: int, vae_scale_factor_spatial: int
) -> Tuple[int, int, int]:
"""Calculate video dimensions from latent tensor dimensions.
This function converts latent temporal and spatial dimensions to video temporal and spatial dimensions by
multiplying the latent num_frames/height/width by the VAE scale factors.
Args:
latents (torch.Tensor): The latent tensor. Must have 5 dimensions.
Expected shape: [batch, channels, frames, height, width]
vae_scale_factor_temporal (int): The scale factor used by the VAE to compress the temporal dimension.
Typically 4 for most video VAEs (the video has 4x more frames than the latents)
vae_scale_factor_spatial (int): The scale factor used by the VAE to compress the spatial dimensions.
Typically 8 for most VAEs (the video is 8x larger than the latents in each spatial dimension)
Returns:
Tuple[int, int, int]: The calculated dimensions as (num_frames, height, width)
Raises:
ValueError: If the latents tensor doesn't have 5 dimensions
"""
if latents.ndim != 5:
raise ValueError(f"latents must have 5 dimensions, but got {latents.ndim}")
_, _, num_latent_frames, latent_height, latent_width = latents.shape
num_frames = (num_latent_frames - 1) * vae_scale_factor_temporal + 1
height = latent_height * vae_scale_factor_spatial
width = latent_width * vae_scale_factor_spatial
return num_frames, height, width
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
@@ -94,7 +186,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
class WanInputStep(ModularPipelineBlocks):
class WanTextInputStep(ModularPipelineBlocks):
model_name = "wan"
@property
@@ -109,14 +201,15 @@ class WanInputStep(ModularPipelineBlocks):
)
@property
def inputs(self) -> List[InputParam]:
def expected_components(self) -> List[ComponentSpec]:
return [
InputParam("num_videos_per_prompt", default=1),
ComponentSpec("transformer", WanTransformer3DModel),
]
@property
def intermediate_inputs(self) -> List[str]:
def inputs(self) -> List[InputParam]:
return [
InputParam("num_videos_per_prompt", default=1),
InputParam(
"prompt_embeds",
required=True,
@@ -141,19 +234,7 @@ class WanInputStep(ModularPipelineBlocks):
OutputParam(
"dtype",
type_hint=torch.dtype,
description="Data type of model tensor inputs (determined by `prompt_embeds`)",
),
OutputParam(
"prompt_embeds",
type_hint=torch.Tensor,
kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields
description="text embeddings used to guide the image generation",
),
OutputParam(
"negative_prompt_embeds",
type_hint=torch.Tensor,
kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields
description="negative text embeddings used to guide the image generation",
description="Data type of model tensor inputs (determined by `transformer.dtype`)",
),
]
@@ -194,6 +275,140 @@ class WanInputStep(ModularPipelineBlocks):
return components, state
class WanAdditionalInputsStep(ModularPipelineBlocks):
model_name = "wan"
def __init__(
self,
image_latent_inputs: List[str] = ["first_frame_latents"],
additional_batch_inputs: List[str] = [],
):
"""Initialize a configurable step that standardizes the inputs for the denoising step. It:\n"
This step handles multiple common tasks to prepare inputs for the denoising step:
1. For encoded image latents, use it update height/width if None, and expands batch size
2. For additional_batch_inputs: Only expands batch dimensions to match final batch size
This is a dynamic block that allows you to configure which inputs to process.
Args:
image_latent_inputs (List[str], optional): Names of image latent tensors to process.
In additional to adjust batch size of these inputs, they will be used to determine height/width. Can be
a single string or list of strings. Defaults to ["first_frame_latents"].
additional_batch_inputs (List[str], optional):
Names of additional conditional input tensors to expand batch size. These tensors will only have their
batch dimensions adjusted to match the final batch size. Can be a single string or list of strings.
Defaults to [].
Examples:
# Configure to process first_frame_latents (default behavior)
WanAdditionalInputsStep()
# Configure to process multiple image latent inputs
WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents", "last_frame_latents"])
# Configure to process image latents and additional batch inputs
WanAdditionalInputsStep(
image_latent_inputs=["first_frame_latents"], additional_batch_inputs=["image_embeds"]
)
"""
if not isinstance(image_latent_inputs, list):
image_latent_inputs = [image_latent_inputs]
if not isinstance(additional_batch_inputs, list):
additional_batch_inputs = [additional_batch_inputs]
self._image_latent_inputs = image_latent_inputs
self._additional_batch_inputs = additional_batch_inputs
super().__init__()
@property
def description(self) -> str:
# Functionality section
summary_section = (
"Input processing step that:\n"
" 1. For image latent inputs: Updates height/width if None, and expands batch size\n"
" 2. For additional batch inputs: Expands batch dimensions to match final batch size"
)
# Inputs info
inputs_info = ""
if self._image_latent_inputs or self._additional_batch_inputs:
inputs_info = "\n\nConfigured inputs:"
if self._image_latent_inputs:
inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}"
if self._additional_batch_inputs:
inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"
# Placement guidance
placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."
return summary_section + inputs_info + placement_section
@property
def inputs(self) -> List[InputParam]:
inputs = [
InputParam(name="num_videos_per_prompt", default=1),
InputParam(name="batch_size", required=True),
InputParam(name="height"),
InputParam(name="width"),
InputParam(name="num_frames"),
]
# Add image latent inputs
for image_latent_input_name in self._image_latent_inputs:
inputs.append(InputParam(name=image_latent_input_name))
# Add additional batch inputs
for input_name in self._additional_batch_inputs:
inputs.append(InputParam(name=input_name))
return inputs
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
# Process image latent inputs (num_frames/height/width calculation and batch expansion)
for image_latent_input_name in self._image_latent_inputs:
image_latent_tensor = getattr(block_state, image_latent_input_name)
if image_latent_tensor is None:
continue
# 1. Calculate num_frames, height/width from latents
num_frames, height, width = calculate_dimension_from_latents(
image_latent_tensor, components.vae_scale_factor_temporal, components.vae_scale_factor_spatial
)
block_state.num_frames = block_state.num_frames or num_frames
block_state.height = block_state.height or height
block_state.width = block_state.width or width
# 2. Expand batch size
image_latent_tensor = repeat_tensor_to_batch_size(
input_name=image_latent_input_name,
input_tensor=image_latent_tensor,
num_videos_per_prompt=block_state.num_videos_per_prompt,
batch_size=block_state.batch_size,
)
setattr(block_state, image_latent_input_name, image_latent_tensor)
# Process additional batch inputs (only batch expansion)
for input_name in self._additional_batch_inputs:
input_tensor = getattr(block_state, input_name)
if input_tensor is None:
continue
# Only expand batch size
input_tensor = repeat_tensor_to_batch_size(
input_name=input_name,
input_tensor=input_tensor,
num_videos_per_prompt=block_state.num_videos_per_prompt,
batch_size=block_state.batch_size,
)
setattr(block_state, input_name, input_tensor)
self.set_block_state(state, block_state)
return components, state
class WanSetTimestepsStep(ModularPipelineBlocks):
model_name = "wan"
@@ -215,26 +430,15 @@ class WanSetTimestepsStep(ModularPipelineBlocks):
InputParam("sigmas"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
OutputParam(
"num_inference_steps",
type_hint=int,
description="The number of denoising steps to perform at inference time",
),
]
@torch.no_grad()
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.device = components._execution_device
device = components._execution_device
block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
components.scheduler,
block_state.num_inference_steps,
block_state.device,
device,
block_state.timesteps,
block_state.sigmas,
)
@@ -246,10 +450,6 @@ class WanSetTimestepsStep(ModularPipelineBlocks):
class WanPrepareLatentsStep(ModularPipelineBlocks):
model_name = "wan"
@property
def expected_components(self) -> List[ComponentSpec]:
return []
@property
def description(self) -> str:
return "Prepare latents step that prepares the latents for the text-to-video generation process"
@@ -262,11 +462,6 @@ class WanPrepareLatentsStep(ModularPipelineBlocks):
InputParam("num_frames", type_hint=int),
InputParam("latents", type_hint=Optional[torch.Tensor]),
InputParam("num_videos_per_prompt", type_hint=int, default=1),
]
@property
def intermediate_inputs(self) -> List[InputParam]:
return [
InputParam("generator"),
InputParam(
"batch_size",
@@ -337,29 +532,106 @@ class WanPrepareLatentsStep(ModularPipelineBlocks):
@torch.no_grad()
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
self.check_inputs(components, block_state)
device = components._execution_device
dtype = torch.float32 # Wan latents should be torch.float32 for best quality
block_state.height = block_state.height or components.default_height
block_state.width = block_state.width or components.default_width
block_state.num_frames = block_state.num_frames or components.default_num_frames
block_state.device = components._execution_device
block_state.dtype = torch.float32 # Wan latents should be torch.float32 for best quality
block_state.num_channels_latents = components.num_channels_latents
self.check_inputs(components, block_state)
block_state.latents = self.prepare_latents(
components,
block_state.batch_size * block_state.num_videos_per_prompt,
block_state.num_channels_latents,
block_state.height,
block_state.width,
block_state.num_frames,
block_state.dtype,
block_state.device,
block_state.generator,
block_state.latents,
batch_size=block_state.batch_size * block_state.num_videos_per_prompt,
num_channels_latents=components.num_channels_latents,
height=block_state.height,
width=block_state.width,
num_frames=block_state.num_frames,
dtype=dtype,
device=device,
generator=block_state.generator,
latents=block_state.latents,
)
self.set_block_state(state, block_state)
return components, state
class WanPrepareFirstFrameLatentsStep(ModularPipelineBlocks):
model_name = "wan"
@property
def description(self) -> str:
return "step that prepares the masked first frame latents and add it to the latent condition"
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("first_frame_latents", type_hint=Optional[torch.Tensor]),
InputParam("num_frames", type_hint=int),
]
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
batch_size, _, _, latent_height, latent_width = block_state.first_frame_latents.shape
mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width)
mask_lat_size[:, :, list(range(1, block_state.num_frames))] = 0
first_frame_mask = mask_lat_size[:, :, 0:1]
first_frame_mask = torch.repeat_interleave(
first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
)
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
mask_lat_size = mask_lat_size.view(
batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
)
mask_lat_size = mask_lat_size.transpose(1, 2)
mask_lat_size = mask_lat_size.to(block_state.first_frame_latents.device)
block_state.first_frame_latents = torch.concat([mask_lat_size, block_state.first_frame_latents], dim=1)
self.set_block_state(state, block_state)
return components, state
class WanPrepareFirstLastFrameLatentsStep(ModularPipelineBlocks):
model_name = "wan"
@property
def description(self) -> str:
return "step that prepares the masked latents with first and last frames and add it to the latent condition"
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("first_last_frame_latents", type_hint=Optional[torch.Tensor]),
InputParam("num_frames", type_hint=int),
]
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
batch_size, _, _, latent_height, latent_width = block_state.first_last_frame_latents.shape
mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width)
mask_lat_size[:, :, list(range(1, block_state.num_frames - 1))] = 0
first_frame_mask = mask_lat_size[:, :, 0:1]
first_frame_mask = torch.repeat_interleave(
first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
)
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
mask_lat_size = mask_lat_size.view(
batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
)
mask_lat_size = mask_lat_size.transpose(1, 2)
mask_lat_size = mask_lat_size.to(block_state.first_last_frame_latents.device)
block_state.first_last_frame_latents = torch.concat(
[mask_lat_size, block_state.first_last_frame_latents], dim=1
)
self.set_block_state(state, block_state)
return components, state
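A quick worked example of the two helpers used by `WanAdditionalInputsStep`, assuming the scale factors the docstrings describe as typical (temporal 4, spatial 8); the latent shape and channel count below are illustrative.

import torch

first_frame_latents = torch.randn(1, 16, 21, 60, 104)  # [batch, channels, latent_frames, latent_h, latent_w]

# calculate_dimension_from_latents: (21 - 1) * 4 + 1 = 81 frames, 60 * 8 = 480, 104 * 8 = 832
num_frames = (first_frame_latents.shape[2] - 1) * 4 + 1
height = first_frame_latents.shape[3] * 8
width = first_frame_latents.shape[4] * 8
print(num_frames, height, width)  # 81 480 832

# repeat_tensor_to_batch_size: a batch of 1 is repeated to batch_size * num_videos_per_prompt
expanded = first_frame_latents.repeat_interleave(2 * 2, dim=0)  # batch_size=2, num_videos_per_prompt=2
print(expanded.shape)  # torch.Size([4, 16, 21, 60, 104])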

View File

@@ -29,7 +29,7 @@ from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class WanDecodeStep(ModularPipelineBlocks):
class WanImageVaeDecoderStep(ModularPipelineBlocks):
model_name = "wan"
@property
@@ -50,12 +50,6 @@ class WanDecodeStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("output_type", default="pil"),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"latents",
@@ -80,25 +74,20 @@ class WanDecodeStep(ModularPipelineBlocks):
block_state = self.get_block_state(state)
vae_dtype = components.vae.dtype
if not block_state.output_type == "latent":
latents = block_state.latents
latents_mean = (
torch.tensor(components.vae.config.latents_mean)
.view(1, components.vae.config.z_dim, 1, 1, 1)
.to(latents.device, latents.dtype)
)
latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
1, components.vae.config.z_dim, 1, 1, 1
).to(latents.device, latents.dtype)
latents = latents / latents_std + latents_mean
latents = latents.to(vae_dtype)
block_state.videos = components.vae.decode(latents, return_dict=False)[0]
else:
block_state.videos = block_state.latents
block_state.videos = components.video_processor.postprocess_video(
block_state.videos, output_type=block_state.output_type
latents = block_state.latents
latents_mean = (
torch.tensor(components.vae.config.latents_mean)
.view(1, components.vae.config.z_dim, 1, 1, 1)
.to(latents.device, latents.dtype)
)
latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
1, components.vae.config.z_dim, 1, 1, 1
).to(latents.device, latents.dtype)
latents = latents / latents_std + latents_mean
latents = latents.to(vae_dtype)
block_state.videos = components.vae.decode(latents, return_dict=False)[0]
block_state.videos = components.video_processor.postprocess_video(block_state.videos, output_type="np")
self.set_block_state(state, block_state)
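The decode step un-normalizes the latents channel-wise before calling `vae.decode`; because the block builds the reciprocal of the configured std, dividing by it is the same as multiplying by the std itself. A small numeric sketch of that transform (the channel count and statistics below are illustrative placeholders, not Wan's actual VAE config):

import torch

z_dim = 4                                        # illustrative latent channel count
latents = torch.randn(1, z_dim, 2, 8, 8)         # [batch, channels, frames, h, w]
latents_mean = torch.tensor([0.1, -0.2, 0.0, 0.3]).view(1, z_dim, 1, 1, 1)
latents_std = torch.tensor([1.5, 0.9, 1.1, 1.3]).view(1, z_dim, 1, 1, 1)

inv_std = 1.0 / latents_std                      # what the block builds from config.latents_std
denormalized = latents / inv_std + latents_mean
assert torch.allclose(denormalized, latents * latents_std + latents_mean)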

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Tuple
from typing import Any, Dict, List, Tuple
import torch
@@ -27,16 +27,156 @@ from ..modular_pipeline import (
ModularPipelineBlocks,
PipelineState,
)
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam
from .modular_pipeline import WanModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class WanLoopBeforeDenoiser(ModularPipelineBlocks):
model_name = "wan"
@property
def description(self) -> str:
return (
"step within the denoising loop that prepares the latent input for the denoiser. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `WanDenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
),
InputParam(
"dtype",
required=True,
type_hint=torch.dtype,
description="The dtype of the model inputs. Can be generated in input step.",
),
]
@torch.no_grad()
def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
block_state.latent_model_input = block_state.latents.to(block_state.dtype)
return components, block_state
class WanImage2VideoLoopBeforeDenoiser(ModularPipelineBlocks):
model_name = "wan"
@property
def description(self) -> str:
return (
"step within the denoising loop that prepares the latent input for the denoiser. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `WanDenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
),
InputParam(
"first_frame_latents",
required=True,
type_hint=torch.Tensor,
description="The first frame latents to use for the denoising process. Can be generated in prepare_first_frame_latents step.",
),
InputParam(
"dtype",
required=True,
type_hint=torch.dtype,
description="The dtype of the model inputs. Can be generated in input step.",
),
]
@torch.no_grad()
def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
block_state.latent_model_input = torch.cat([block_state.latents, block_state.first_frame_latents], dim=1).to(
block_state.dtype
)
return components, block_state
class WanFLF2VLoopBeforeDenoiser(ModularPipelineBlocks):
model_name = "wan"
@property
def description(self) -> str:
return (
"step within the denoising loop that prepares the latent input for the denoiser. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `WanDenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
),
InputParam(
"first_last_frame_latents",
required=True,
type_hint=torch.Tensor,
description="The first and last frame latents to use for the denoising process. Can be generated in prepare_first_last_frame_latents step.",
),
InputParam(
"dtype",
required=True,
type_hint=torch.dtype,
description="The dtype of the model inputs. Can be generated in input step.",
),
]
@torch.no_grad()
def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
block_state.latent_model_input = torch.cat(
[block_state.latents, block_state.first_last_frame_latents], dim=1
).to(block_state.dtype)
return components, block_state
class WanLoopDenoiser(ModularPipelineBlocks):
model_name = "wan"
def __init__(
self,
guider_input_fields: Dict[str, Any] = {"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds")},
):
"""Initialize a denoiser block that calls the denoiser model. This block is used in Wan2.1.
Args:
guider_input_fields: A dictionary that maps each argument expected by the denoiser model
(for example, "encoder_hidden_states") to data stored on 'block_state'. The value can be either:
- A tuple of strings. For instance, {"encoder_hidden_states": ("prompt_embeds",
"negative_prompt_embeds")} tells the guider to read `block_state.prompt_embeds` and
`block_state.negative_prompt_embeds` and pass them as the conditional and unconditional batches of
'encoder_hidden_states'.
- A string. For example, {"encoder_hidden_image": "image_embeds"} makes the guider forward
`block_state.image_embeds` for both conditional and unconditional batches.
"""
if not isinstance(guider_input_fields, dict):
raise ValueError(f"guider_input_fields must be a dictionary but is {type(guider_input_fields)}")
self._guider_input_fields = guider_input_fields
super().__init__()
@property
def expected_components(self) -> List[ComponentSpec]:
return [
@@ -59,49 +199,30 @@ class WanLoopDenoiser(ModularPipelineBlocks):
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
inputs = [
InputParam("attention_kwargs"),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
),
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
),
InputParam(
kwargs_type="denoiser_input_fields",
description=(
"All conditional model inputs that need to be prepared with guider. "
"It should contain prompt_embeds/negative_prompt_embeds. "
"Please add `kwargs_type=denoiser_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
),
),
]
guider_input_names = []
for value in self._guider_input_fields.values():
if isinstance(value, tuple):
guider_input_names.extend(value)
else:
guider_input_names.append(value)
for name in guider_input_names:
inputs.append(InputParam(name=name, required=True, type_hint=torch.Tensor))
return inputs
@torch.no_grad()
def __call__(
self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
) -> PipelineState:
# Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
# to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
guider_inputs = {
"prompt_embeds": (
getattr(block_state, "prompt_embeds", None),
getattr(block_state, "negative_prompt_embeds", None),
),
}
transformer_dtype = components.transformer.dtype
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
# The guider splits model inputs into separate batches for conditional/unconditional predictions.
@@ -112,22 +233,26 @@ class WanLoopDenoiser(ModularPipelineBlocks):
# {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
# ]
# Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
guider_state = components.guider.prepare_inputs(guider_inputs)
guider_state = components.guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields)
# run the denoiser for each guidance batch
for guider_state_batch in guider_state:
components.guider.prepare_models(components.transformer)
cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
prompt_embeds = cond_kwargs.pop("prompt_embeds")
cond_kwargs = guider_state_batch.as_dict()
cond_kwargs = {
k: v.to(block_state.dtype) if isinstance(v, torch.Tensor) else v
for k, v in cond_kwargs.items()
if k in self._guider_input_fields.keys()
}
# Predict the noise residual
# store the noise_pred in guider_state_batch so that we can apply guidance across all batches
guider_state_batch.noise_pred = components.transformer(
hidden_states=block_state.latents.to(transformer_dtype),
timestep=t.flatten(),
encoder_hidden_states=prompt_embeds,
hidden_states=block_state.latent_model_input.to(block_state.dtype),
timestep=t.expand(block_state.latent_model_input.shape[0]).to(block_state.dtype),
attention_kwargs=block_state.attention_kwargs,
return_dict=False,
**cond_kwargs,
)[0]
components.guider.cleanup_models(components.transformer)
@@ -137,6 +262,141 @@ class WanLoopDenoiser(ModularPipelineBlocks):
return components, block_state
class Wan22LoopDenoiser(ModularPipelineBlocks):
model_name = "wan"
def __init__(
self,
guider_input_fields: Dict[str, Any] = {"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds")},
):
"""Initialize a denoiser block that calls the denoiser model. This block is used in Wan2.2.
Args:
guider_input_fields: A dictionary that maps each argument expected by the denoiser model
(for example, "encoder_hidden_states") to data stored on `block_state`. The value can be either:
- A tuple of strings. For instance, `{"encoder_hidden_states": ("prompt_embeds",
"negative_prompt_embeds")}` tells the guider to read `block_state.prompt_embeds` and
`block_state.negative_prompt_embeds` and pass them as the conditional and unconditional batches of
`encoder_hidden_states`.
- A string. For example, `{"encoder_hidden_image": "image_embeds"}` makes the guider forward
`block_state.image_embeds` for both conditional and unconditional batches.
"""
if not isinstance(guider_input_fields, dict):
raise ValueError(f"guider_input_fields must be a dictionary but is {type(guider_input_fields)}")
self._guider_input_fields = guider_input_fields
super().__init__()
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 4.0}),
default_creation_method="from_config",
),
ComponentSpec(
"guider_2",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 3.0}),
default_creation_method="from_config",
),
ComponentSpec("transformer", WanTransformer3DModel),
ComponentSpec("transformer_2", WanTransformer3DModel),
]
@property
def description(self) -> str:
return (
"Step within the denoising loop that denoise the latents with guidance. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `WanDenoiseLoopWrapper`)"
)
@property
def expected_configs(self) -> List[ConfigSpec]:
return [
ConfigSpec(
name="boundary_ratio",
default=0.875,
description="The boundary ratio to divide the denoising loop into high noise and low noise stages.",
),
]
@property
def inputs(self) -> List[Tuple[str, Any]]:
inputs = [
InputParam("attention_kwargs"),
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
),
]
guider_input_names = []
for value in self._guider_input_fields.values():
if isinstance(value, tuple):
guider_input_names.extend(value)
else:
guider_input_names.append(value)
for name in guider_input_names:
inputs.append(InputParam(name=name, required=True, type_hint=torch.Tensor))
return inputs
@torch.no_grad()
def __call__(
self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
) -> PipelineState:
boundary_timestep = components.config.boundary_ratio * components.num_train_timesteps
if t >= boundary_timestep:
block_state.current_model = components.transformer
block_state.guider = components.guider
else:
block_state.current_model = components.transformer_2
block_state.guider = components.guider_2
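# Example (assuming the scheduler's num_train_timesteps is 1000): with the default
# boundary_ratio of 0.875 the boundary timestep is 875, so high-noise steps (t >= 875)
# are handled by `transformer` with `guider`, and the remaining low-noise steps by
# `transformer_2` with `guider_2`.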
block_state.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
# The guider splits model inputs into separate batches for conditional/unconditional predictions.
# For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
# you will get a guider_state with two batches:
# guider_state = [
# {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
# {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
# ]
# Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
guider_state = block_state.guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields)
# run the denoiser for each guidance batch
for guider_state_batch in guider_state:
block_state.guider.prepare_models(block_state.current_model)
cond_kwargs = guider_state_batch.as_dict()
cond_kwargs = {
k: v.to(block_state.dtype) if isinstance(v, torch.Tensor) else v
for k, v in cond_kwargs.items()
if k in self._guider_input_fields.keys()
}
# Predict the noise residual
# store the noise_pred in guider_state_batch so that we can apply guidance across all batches
guider_state_batch.noise_pred = block_state.current_model(
hidden_states=block_state.latent_model_input.to(block_state.dtype),
timestep=t.expand(block_state.latent_model_input.shape[0]).to(block_state.dtype),
attention_kwargs=block_state.attention_kwargs,
return_dict=False,
**cond_kwargs,
)[0]
block_state.guider.cleanup_models(block_state.current_model)
# Perform guidance
block_state.noise_pred = block_state.guider(guider_state)[0]
return components, block_state
class WanLoopAfterDenoiser(ModularPipelineBlocks):
model_name = "wan"
@@ -154,20 +414,6 @@ class WanLoopAfterDenoiser(ModularPipelineBlocks):
"object (e.g. `WanDenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[Tuple[str, Any]]:
return []
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam("generator"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]
@torch.no_grad()
def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
# Perform scheduler step using the predicted output
@@ -198,18 +444,11 @@ class WanDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
@property
def loop_expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 5.0}),
default_creation_method="from_config",
),
ComponentSpec("scheduler", UniPCMultistepScheduler),
ComponentSpec("transformer", WanTransformer3DModel),
]
@property
def loop_intermediate_inputs(self) -> List[InputParam]:
def loop_inputs(self) -> List[InputParam]:
return [
InputParam(
"timesteps",
@@ -248,7 +487,12 @@ class WanDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
class WanDenoiseStep(WanDenoiseLoopWrapper):
block_classes = [
WanLoopDenoiser,
WanLoopBeforeDenoiser,
WanLoopDenoiser(
guider_input_fields={
"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
}
),
WanLoopAfterDenoiser,
]
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
@@ -259,7 +503,110 @@ class WanDenoiseStep(WanDenoiseLoopWrapper):
"Denoise step that iteratively denoise the latents. \n"
"Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
" - `WanLoopBeforeDenoiser`\n"
" - `WanLoopDenoiser`\n"
" - `WanLoopAfterDenoiser`\n"
"This block supports both text2vid tasks."
"This block supports text-to-video tasks for wan2.1."
)
class Wan22DenoiseStep(WanDenoiseLoopWrapper):
block_classes = [
WanLoopBeforeDenoiser,
Wan22LoopDenoiser(
guider_input_fields={
"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
}
),
WanLoopAfterDenoiser,
]
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. \n"
"Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
" - `WanLoopBeforeDenoiser`\n"
" - `Wan22LoopDenoiser`\n"
" - `WanLoopAfterDenoiser`\n"
"This block supports text-to-video tasks for Wan2.2."
)
class WanImage2VideoDenoiseStep(WanDenoiseLoopWrapper):
block_classes = [
WanImage2VideoLoopBeforeDenoiser,
WanLoopDenoiser(
guider_input_fields={
"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
"encoder_hidden_states_image": "image_embeds",
}
),
WanLoopAfterDenoiser,
]
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. \n"
"Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
" - `WanImage2VideoLoopBeforeDenoiser`\n"
" - `WanLoopDenoiser`\n"
" - `WanLoopAfterDenoiser`\n"
"This block supports image-to-video tasks for wan2.1."
)
class Wan22Image2VideoDenoiseStep(WanDenoiseLoopWrapper):
block_classes = [
WanImage2VideoLoopBeforeDenoiser,
Wan22LoopDenoiser(
guider_input_fields={
"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
}
),
WanLoopAfterDenoiser,
]
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. \n"
"Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
" - `WanImage2VideoLoopBeforeDenoiser`\n"
" - `WanLoopDenoiser`\n"
" - `WanLoopAfterDenoiser`\n"
"This block supports image-to-video tasks for Wan2.2."
)
class WanFLF2VDenoiseStep(WanDenoiseLoopWrapper):
block_classes = [
WanFLF2VLoopBeforeDenoiser,
WanLoopDenoiser(
guider_input_fields={
"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
"encoder_hidden_states_image": "image_embeds",
}
),
WanLoopAfterDenoiser,
]
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. \n"
"Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
" - `WanFLF2VLoopBeforeDenoiser`\n"
" - `WanLoopDenoiser`\n"
" - `WanLoopAfterDenoiser`\n"
"This block supports FLF2V tasks for wan2.1."
)

View File

@@ -15,21 +15,29 @@
import html
from typing import List, Optional, Union
import numpy as np
import PIL
import regex as re
import torch
from transformers import AutoTokenizer, UMT5EncoderModel
from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...utils import is_ftfy_available, logging
from ...image_processor import PipelineImageInput
from ...models import AutoencoderKLWan
from ...utils import is_ftfy_available, is_torchvision_available, logging
from ...video_processor import VideoProcessor
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import WanModularPipeline
if is_ftfy_available():
import ftfy
if is_torchvision_available():
from torchvision import transforms
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -51,6 +59,103 @@ def prompt_clean(text):
return text
def get_t5_prompt_embeds(
text_encoder: UMT5EncoderModel,
tokenizer: AutoTokenizer,
prompt: Union[str, List[str]],
max_sequence_length: int,
device: torch.device,
):
dtype = text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
prompt = [prompt_clean(u) for u in prompt]
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_attention_mask=True,
return_tensors="pt",
)
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
seq_lens = mask.gt(0).sum(dim=1).long()
prompt_embeds = text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
prompt_embeds = torch.stack(
[torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
)
return prompt_embeds
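# Minimal usage sketch for the helper above (assuming `text_encoder` and `tokenizer` are already
# loaded from a Wan checkpoint; the variable names are illustrative only):
#   prompt_embeds = get_t5_prompt_embeds(
#       text_encoder=text_encoder,
#       tokenizer=tokenizer,
#       prompt="a cat surfing a wave at sunset",
#       max_sequence_length=512,
#       device=torch.device("cuda"),
#   )
# The result has shape (batch_size, max_sequence_length, hidden_dim); positions past each prompt's
# true token length are zero-padded instead of keeping the encoder outputs for padding tokens.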
def encode_image(
image: PipelineImageInput,
image_processor: CLIPImageProcessor,
image_encoder: CLIPVisionModel,
device: Optional[torch.device] = None,
):
image = image_processor(images=image, return_tensors="pt").to(device)
image_embeds = image_encoder(**image, output_hidden_states=True)
return image_embeds.hidden_states[-2]
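# Note: the penultimate hidden state (hidden_states[-2]) is returned as the image conditioning,
# matching the CLIP-vision conditioning used by the Wan image-to-video checkpoints.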
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
def encode_vae_image(
video_tensor: torch.Tensor,
vae: AutoencoderKLWan,
generator: torch.Generator,
device: torch.device,
dtype: torch.dtype,
latent_channels: int = 16,
):
if not isinstance(video_tensor, torch.Tensor):
raise ValueError(f"Expected video_tensor to be a tensor, got {type(video_tensor)}.")
if isinstance(generator, list) and len(generator) != video_tensor.shape[0]:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but it is not same as number of images {video_tensor.shape[0]}."
)
video_tensor = video_tensor.to(device=device, dtype=dtype)
if isinstance(generator, list):
video_latents = [
retrieve_latents(vae.encode(video_tensor[i : i + 1]), generator=generator[i], sample_mode="argmax")
for i in range(video_tensor.shape[0])
]
video_latents = torch.cat(video_latents, dim=0)
else:
video_latents = retrieve_latents(vae.encode(video_tensor), sample_mode="argmax")
latents_mean = (
torch.tensor(vae.config.latents_mean)
.view(1, latent_channels, 1, 1, 1)
.to(video_latents.device, video_latents.dtype)
)
latents_std = 1.0 / torch.tensor(vae.config.latents_std).view(1, latent_channels, 1, 1, 1).to(
video_latents.device, video_latents.dtype
)
video_latents = (video_latents - latents_mean) * latents_std
return video_latents
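# Unlike SD-style VAEs that use a single scaling_factor, AutoencoderKLWan normalizes latents with a
# per-channel mean/std. A minimal sketch of the equivalent math (assuming `raw` holds un-normalized
# latents with the default 16 latent channels):
#   mean = torch.tensor(vae.config.latents_mean).view(1, 16, 1, 1, 1)
#   std  = torch.tensor(vae.config.latents_std).view(1, 16, 1, 1, 1)
#   normalized = (raw - mean) / std   # same as (raw - latents_mean) * (1.0 / latents_std) above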
class WanTextEncoderStep(ModularPipelineBlocks):
model_name = "wan"
@@ -71,16 +176,12 @@ class WanTextEncoderStep(ModularPipelineBlocks):
),
]
@property
def expected_configs(self) -> List[ConfigSpec]:
return []
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt"),
InputParam("negative_prompt"),
InputParam("attention_kwargs"),
InputParam("max_sequence_length", default=512),
]
@property
@@ -107,47 +208,13 @@ class WanTextEncoderStep(ModularPipelineBlocks):
):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")
@staticmethod
def _get_t5_prompt_embeds(
components,
prompt: Union[str, List[str]],
max_sequence_length: int,
device: torch.device,
):
dtype = components.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
prompt = [prompt_clean(u) for u in prompt]
text_inputs = components.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_attention_mask=True,
return_tensors="pt",
)
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
seq_lens = mask.gt(0).sum(dim=1).long()
prompt_embeds = components.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
prompt_embeds = torch.stack(
[torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
)
return prompt_embeds
@staticmethod
def encode_prompt(
components,
prompt: str,
device: Optional[torch.device] = None,
num_videos_per_prompt: int = 1,
prepare_unconditional_embeds: bool = True,
negative_prompt: Optional[str] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
max_sequence_length: int = 512,
):
r"""
@@ -158,32 +225,29 @@ class WanTextEncoderStep(ModularPipelineBlocks):
prompt to be encoded
device: (`torch.device`):
torch device
num_videos_per_prompt (`int`):
number of videos that should be generated per prompt
prepare_unconditional_embeds (`bool`):
whether to prepare unconditional embeddings or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
max_sequence_length (`int`, defaults to `512`):
The maximum number of text tokens to be used for the generation process.
"""
device = device or components._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0]
if not isinstance(prompt, list):
prompt = [prompt]
batch_size = len(prompt)
if prompt_embeds is None:
prompt_embeds = WanTextEncoderStep._get_t5_prompt_embeds(components, prompt, max_sequence_length, device)
prompt_embeds = get_t5_prompt_embeds(
text_encoder=components.text_encoder,
tokenizer=components.tokenizer,
prompt=prompt,
max_sequence_length=max_sequence_length,
device=device,
)
if prepare_unconditional_embeds and negative_prompt_embeds is None:
if prepare_unconditional_embeds:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
@@ -199,18 +263,14 @@ class WanTextEncoderStep(ModularPipelineBlocks):
" the batch size of `prompt`."
)
negative_prompt_embeds = WanTextEncoderStep._get_t5_prompt_embeds(
components, negative_prompt, max_sequence_length, device
negative_prompt_embeds = get_t5_prompt_embeds(
text_encoder=components.text_encoder,
tokenizer=components.tokenizer,
prompt=negative_prompt,
max_sequence_length=max_sequence_length,
device=device,
)
bs_embed, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
if prepare_unconditional_embeds:
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
return prompt_embeds, negative_prompt_embeds
@torch.no_grad()
@@ -219,7 +279,6 @@ class WanTextEncoderStep(ModularPipelineBlocks):
block_state = self.get_block_state(state)
self.check_inputs(block_state)
block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
block_state.device = components._execution_device
# Encode input prompt
@@ -227,16 +286,382 @@ class WanTextEncoderStep(ModularPipelineBlocks):
block_state.prompt_embeds,
block_state.negative_prompt_embeds,
) = self.encode_prompt(
components,
block_state.prompt,
block_state.device,
1,
block_state.prepare_unconditional_embeds,
block_state.negative_prompt,
prompt_embeds=None,
negative_prompt_embeds=None,
components=components,
prompt=block_state.prompt,
device=block_state.device,
prepare_unconditional_embeds=components.requires_unconditional_embeds,
negative_prompt=block_state.negative_prompt,
max_sequence_length=block_state.max_sequence_length,
)
# Add outputs
self.set_block_state(state, block_state)
return components, state
class WanImageResizeStep(ModularPipelineBlocks):
model_name = "wan"
@property
def description(self) -> str:
return "Image Resize step that resize the image to the target area (height * width) while maintaining the aspect ratio."
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("image", type_hint=PIL.Image.Image, required=True),
InputParam("height", type_hint=int, default=480),
InputParam("width", type_hint=int, default=832),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam("resized_image", type_hint=PIL.Image.Image),
]
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
max_area = block_state.height * block_state.width
image = block_state.image
aspect_ratio = image.height / image.width
mod_value = components.vae_scale_factor_spatial * components.patch_size_spatial
block_state.height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
block_state.width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
block_state.resized_image = image.resize((block_state.width, block_state.height))
self.set_block_state(state, block_state)
return components, state
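# Worked example (assuming the default height=480, width=832 and a 1280x720 input image, with
# vae_scale_factor_spatial=8 and patch_size_spatial=2, i.e. mod_value=16):
#   max_area     = 480 * 832 = 399360
#   aspect_ratio = 720 / 1280 = 0.5625
#   height       = round(sqrt(399360 * 0.5625)) // 16 * 16 = 474 // 16 * 16 = 464
#   width        = round(sqrt(399360 / 0.5625)) // 16 * 16 = 843 // 16 * 16 = 832
# so the image is resized to 832x464, approximately preserving its aspect ratio while staying close
# to the requested area.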
class WanImageCropResizeStep(ModularPipelineBlocks):
model_name = "wan"
@property
def description(self) -> str:
return "Image Resize step that resize the last_image to the same size of first frame image with center crop."
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(
"resized_image", type_hint=PIL.Image.Image, required=True, description="The resized first frame image"
),
InputParam("last_image", type_hint=PIL.Image.Image, required=True, description="The last frameimage"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam("resized_last_image", type_hint=PIL.Image.Image),
]
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
height = block_state.resized_image.height
width = block_state.resized_image.width
image = block_state.last_image
# Calculate resize ratio to match first frame dimensions
resize_ratio = max(width / image.width, height / image.height)
# Resize the image
width = round(image.width * resize_ratio)
height = round(image.height * resize_ratio)
size = [width, height]
resized_image = transforms.functional.center_crop(image, size)
block_state.resized_last_image = resized_image
self.set_block_state(state, block_state)
return components, state
class WanImageEncoderStep(ModularPipelineBlocks):
model_name = "wan"
@property
def description(self) -> str:
return "Image Encoder step that generate image_embeds based on first frame image to guide the video generation"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("image_processor", CLIPImageProcessor),
ComponentSpec("image_encoder", CLIPVisionModel),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("resized_image", type_hint=PIL.Image.Image, required=True),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam("image_embeds", type_hint=torch.Tensor, description="The image embeddings"),
]
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
device = components._execution_device
image = block_state.resized_image
image_embeds = encode_image(
image_processor=components.image_processor,
image_encoder=components.image_encoder,
image=image,
device=device,
)
block_state.image_embeds = image_embeds
self.set_block_state(state, block_state)
return components, state
class WanFirstLastFrameImageEncoderStep(ModularPipelineBlocks):
model_name = "wan"
@property
def description(self) -> str:
return "Image Encoder step that generate image_embeds based on first and last frame images to guide the video generation"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("image_processor", CLIPImageProcessor),
ComponentSpec("image_encoder", CLIPVisionModel),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("resized_image", type_hint=PIL.Image.Image, required=True),
InputParam("resized_last_image", type_hint=PIL.Image.Image, required=True),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam("image_embeds", type_hint=torch.Tensor, description="The image embeddings"),
]
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
device = components._execution_device
first_frame_image = block_state.resized_image
last_frame_image = block_state.resized_last_image
image_embeds = encode_image(
image_processor=components.image_processor,
image_encoder=components.image_encoder,
image=[first_frame_image, last_frame_image],
device=device,
)
block_state.image_embeds = image_embeds
self.set_block_state(state, block_state)
return components, state
class WanVaeImageEncoderStep(ModularPipelineBlocks):
model_name = "wan"
@property
def description(self) -> str:
return "Vae Image Encoder step that generate condition_latents based on first frame image to guide the video generation"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKLWan),
ComponentSpec(
"video_processor",
VideoProcessor,
config=FrozenDict({"vae_scale_factor": 8}),
default_creation_method="from_config",
),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("resized_image", type_hint=PIL.Image.Image, required=True),
InputParam("height"),
InputParam("width"),
InputParam("num_frames"),
InputParam("generator"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"first_frame_latents",
type_hint=torch.Tensor,
description="video latent representation with the first frame image condition",
),
]
@staticmethod
def check_inputs(components, block_state):
if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or (
block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0
):
raise ValueError(
f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}."
)
if block_state.num_frames is not None and (
block_state.num_frames < 1 or (block_state.num_frames - 1) % components.vae_scale_factor_temporal != 0
):
raise ValueError(
f"`num_frames` has to be greater than 0, and (num_frames - 1) must be divisible by {components.vae_scale_factor_temporal}, but got {block_state.num_frames}."
)
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
self.check_inputs(components, block_state)
image = block_state.resized_image
device = components._execution_device
dtype = torch.float32
height = block_state.height or components.default_height
width = block_state.width or components.default_width
num_frames = block_state.num_frames or components.default_num_frames
image_tensor = components.video_processor.preprocess(image, height=height, width=width).to(
device=device, dtype=dtype
)
if image_tensor.dim() == 4:
image_tensor = image_tensor.unsqueeze(2)
video_tensor = torch.cat(
[
image_tensor,
image_tensor.new_zeros(image_tensor.shape[0], image_tensor.shape[1], num_frames - 1, height, width),
],
dim=2,
).to(device=device, dtype=dtype)
block_state.first_frame_latents = encode_vae_image(
video_tensor=video_tensor,
vae=components.vae,
generator=block_state.generator,
device=device,
dtype=dtype,
latent_channels=components.num_channels_latents,
)
self.set_block_state(state, block_state)
return components, state
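# Shape sketch (assuming the defaults height=480, width=832 and num_frames=81, the usual Wan default):
# the conditioning video is
#   cat([image (B, 3, 1, 480, 832), zeros (B, 3, 80, 480, 832)], dim=2) -> (B, 3, 81, 480, 832)
# i.e. the provided image is the first frame and the remaining frames are zero placeholders, which
# the VAE then encodes into `first_frame_latents`.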
class WanFirstLastFrameVaeImageEncoderStep(ModularPipelineBlocks):
model_name = "wan"
@property
def description(self) -> str:
return "Vae Image Encoder step that generate condition_latents based on first and last frame images to guide the video generation"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKLWan),
ComponentSpec(
"video_processor",
VideoProcessor,
config=FrozenDict({"vae_scale_factor": 8}),
default_creation_method="from_config",
),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("resized_image", type_hint=PIL.Image.Image, required=True),
InputParam("resized_last_image", type_hint=PIL.Image.Image, required=True),
InputParam("height"),
InputParam("width"),
InputParam("num_frames"),
InputParam("generator"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"first_last_frame_latents",
type_hint=torch.Tensor,
description="video latent representation with the first and last frame images condition",
),
]
@staticmethod
def check_inputs(components, block_state):
if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or (
block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0
):
raise ValueError(
f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}."
)
if block_state.num_frames is not None and (
block_state.num_frames < 1 or (block_state.num_frames - 1) % components.vae_scale_factor_temporal != 0
):
raise ValueError(
f"`num_frames` has to be greater than 0, and (num_frames - 1) must be divisible by {components.vae_scale_factor_temporal}, but got {block_state.num_frames}."
)
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
self.check_inputs(components, block_state)
first_frame_image = block_state.resized_image
last_frame_image = block_state.resized_last_image
device = components._execution_device
dtype = torch.float32
height = block_state.height or components.default_height
width = block_state.width or components.default_width
num_frames = block_state.num_frames or components.default_num_frames
first_image_tensor = components.video_processor.preprocess(first_frame_image, height=height, width=width).to(
device=device, dtype=dtype
)
first_image_tensor = first_image_tensor.unsqueeze(2)
last_image_tensor = components.video_processor.preprocess(last_frame_image, height=height, width=width).to(
device=device, dtype=dtype
)
last_image_tensor = last_image_tensor.unsqueeze(2)
video_tensor = torch.cat(
[
first_image_tensor,
first_image_tensor.new_zeros(
first_image_tensor.shape[0], first_image_tensor.shape[1], num_frames - 2, height, width
),
last_image_tensor,
],
dim=2,
).to(device=device, dtype=dtype)
block_state.first_last_frame_latents = encode_vae_image(
video_tensor=video_tensor,
vae=components.vae,
generator=block_state.generator,
device=device,
dtype=dtype,
latent_channels=components.num_channels_latents,
)
self.set_block_state(state, block_state)
return components, state

View File

@@ -16,96 +16,244 @@ from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
from .before_denoise import (
WanInputStep,
WanAdditionalInputsStep,
WanPrepareFirstFrameLatentsStep,
WanPrepareFirstLastFrameLatentsStep,
WanPrepareLatentsStep,
WanSetTimestepsStep,
WanTextInputStep,
)
from .decoders import WanImageVaeDecoderStep
from .denoise import (
Wan22DenoiseStep,
Wan22Image2VideoDenoiseStep,
WanDenoiseStep,
WanFLF2VDenoiseStep,
WanImage2VideoDenoiseStep,
)
from .encoders import (
WanFirstLastFrameImageEncoderStep,
WanFirstLastFrameVaeImageEncoderStep,
WanImageCropResizeStep,
WanImageEncoderStep,
WanImageResizeStep,
WanTextEncoderStep,
WanVaeImageEncoderStep,
)
from .decoders import WanDecodeStep
from .denoise import WanDenoiseStep
from .encoders import WanTextEncoderStep
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# before_denoise: text2vid
class WanBeforeDenoiseStep(SequentialPipelineBlocks):
# wan2.1
# wan2.1: text2vid
class WanCoreDenoiseStep(SequentialPipelineBlocks):
block_classes = [
WanInputStep,
WanTextInputStep,
WanSetTimestepsStep,
WanPrepareLatentsStep,
]
block_names = ["input", "set_timesteps", "prepare_latents"]
@property
def description(self):
return (
"Before denoise step that prepare the inputs for the denoise step.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `WanInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
)
# before_denoise: all task (text2vid,)
class WanAutoBeforeDenoiseStep(AutoPipelineBlocks):
block_classes = [
WanBeforeDenoiseStep,
]
block_names = ["text2vid"]
block_trigger_inputs = [None]
@property
def description(self):
return (
"Before denoise step that prepare the inputs for the denoise step.\n"
+ "This is an auto pipeline block that works for text2vid.\n"
+ " - `WanBeforeDenoiseStep` (text2vid) is used.\n"
)
# denoise: text2vid
class WanAutoDenoiseStep(AutoPipelineBlocks):
block_classes = [
WanDenoiseStep,
]
block_names = ["denoise"]
block_trigger_inputs = [None]
block_names = ["input", "set_timesteps", "prepare_latents", "denoise"]
@property
def description(self):
return (
"denoise block that takes encoded conditions and runs the denoising process.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
+ " - `WanDenoiseStep` is used to denoise the latents\n"
)
# wan2.1: image2video
## image encoder
class WanImage2VideoImageEncoderStep(SequentialPipelineBlocks):
model_name = "wan"
block_classes = [WanImageResizeStep, WanImageEncoderStep]
block_names = ["image_resize", "image_encoder"]
@property
def description(self):
return "Image2Video Image Encoder step that resize the image and encode the image to generate the image embeddings"
## vae encoder
class WanImage2VideoVaeImageEncoderStep(SequentialPipelineBlocks):
model_name = "wan"
block_classes = [WanImageResizeStep, WanVaeImageEncoderStep]
block_names = ["image_resize", "vae_image_encoder"]
@property
def description(self):
return "Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent representation"
## denoise
class WanImage2VideoCoreDenoiseStep(SequentialPipelineBlocks):
block_classes = [
WanTextInputStep,
WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"]),
WanSetTimestepsStep,
WanPrepareLatentsStep,
WanPrepareFirstFrameLatentsStep,
WanImage2VideoDenoiseStep,
]
block_names = [
"input",
"additional_inputs",
"set_timesteps",
"prepare_latents",
"prepare_first_frame_latents",
"denoise",
]
@property
def description(self):
return (
"denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
+ " - `WanPrepareFirstFrameLatentsStep` is used to prepare the first frame latent conditions\n"
+ " - `WanImage2VideoDenoiseStep` is used to denoise the latents\n"
)
# wan2.1: FLF2v
## image encoder
class WanFLF2VImageEncoderStep(SequentialPipelineBlocks):
model_name = "wan"
block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameImageEncoderStep]
block_names = ["image_resize", "last_image_resize", "image_encoder"]
@property
def description(self):
return "FLF2V Image Encoder step that resize and encode and encode the first and last frame images to generate the image embeddings"
## vae encoder
class WanFLF2VVaeImageEncoderStep(SequentialPipelineBlocks):
model_name = "wan"
block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameVaeImageEncoderStep]
block_names = ["image_resize", "last_image_resize", "vae_image_encoder"]
@property
def description(self):
return "FLF2V Vae Image Encoder step that resize and encode and encode the first and last frame images to generate the latent conditions"
## denoise
class WanFLF2VCoreDenoiseStep(SequentialPipelineBlocks):
block_classes = [
WanTextInputStep,
WanAdditionalInputsStep(image_latent_inputs=["first_last_frame_latents"]),
WanSetTimestepsStep,
WanPrepareLatentsStep,
WanPrepareFirstLastFrameLatentsStep,
WanFLF2VDenoiseStep,
]
block_names = [
"input",
"additional_inputs",
"set_timesteps",
"prepare_latents",
"prepare_first_last_frame_latents",
"denoise",
]
@property
def description(self):
return (
"denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
+ " - `WanPrepareFirstLastFrameLatentsStep` is used to prepare the latent conditions\n"
+ " - `WanImage2VideoDenoiseStep` is used to denoise the latents\n"
)
# wan2.1: auto blocks
## image encoder
class WanAutoImageEncoderStep(AutoPipelineBlocks):
block_classes = [WanFLF2VImageEncoderStep, WanImage2VideoImageEncoderStep]
block_names = ["flf2v_image_encoder", "image2video_image_encoder"]
block_trigger_inputs = ["last_image", "image"]
@property
def description(self):
return (
"Image Encoder step that encode the image to generate the image embeddings"
+ "This is an auto pipeline block that works for image2video tasks."
+ " - `WanFLF2VImageEncoderStep` (flf2v) is used when `last_image` is provided."
+ " - `WanImage2VideoImageEncoderStep` (image2video) is used when `image` is provided."
+ " - if `last_image` or `image` is not provided, step will be skipped."
)
## vae encoder
class WanAutoVaeImageEncoderStep(AutoPipelineBlocks):
block_classes = [WanFLF2VVaeImageEncoderStep, WanImage2VideoVaeImageEncoderStep]
block_names = ["flf2v_vae_image_encoder", "image2video_vae_image_encoder"]
block_trigger_inputs = ["last_image", "image"]
@property
def description(self):
return (
"Vae Image Encoder step that encode the image to generate the image latents"
+ "This is an auto pipeline block that works for image2video tasks."
+ " - `WanFLF2VVaeImageEncoderStep` (flf2v) is used when `last_image` is provided."
+ " - `WanImage2VideoVaeImageEncoderStep` (image2video) is used when `image` is provided."
+ " - if `last_image` or `image` is not provided, step will be skipped."
)
## denoise
class WanAutoDenoiseStep(AutoPipelineBlocks):
block_classes = [
WanFLF2VCoreDenoiseStep,
WanImage2VideoCoreDenoiseStep,
WanCoreDenoiseStep,
]
block_names = ["flf2v", "image2video", "text2video"]
block_trigger_inputs = ["first_last_frame_latents", "first_frame_latents", None]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. "
"This is a auto pipeline block that works for text2vid tasks.."
" - `WanDenoiseStep` (denoise) for text2vid tasks."
"This is a auto pipeline block that works for text2video and image2video tasks."
" - `WanCoreDenoiseStep` (text2video) for text2vid tasks."
" - `WanCoreImage2VideoCoreDenoiseStep` (image2video) for image2video tasks."
+ " - if `first_frame_latents` is provided, `WanCoreImage2VideoDenoiseStep` will be used.\n"
+ " - if `first_frame_latents` is not provided, `WanCoreDenoiseStep` will be used.\n"
)
# decode: all task (text2img, img2img, inpainting)
class WanAutoDecodeStep(AutoPipelineBlocks):
block_classes = [WanDecodeStep]
block_names = ["non-inpaint"]
block_trigger_inputs = [None]
@property
def description(self):
return "Decode step that decode the denoised latents into videos outputs.\n - `WanDecodeStep`"
# text2vid
# auto pipeline blocks
class WanAutoBlocks(SequentialPipelineBlocks):
block_classes = [
WanTextEncoderStep,
WanAutoBeforeDenoiseStep,
WanAutoImageEncoderStep,
WanAutoVaeImageEncoderStep,
WanAutoDenoiseStep,
WanAutoDecodeStep,
WanImageVaeDecoderStep,
]
block_names = [
"text_encoder",
"before_denoise",
"image_encoder",
"vae_image_encoder",
"denoise",
"decoder",
"decode",
]
@property
@@ -116,29 +264,211 @@ class WanAutoBlocks(SequentialPipelineBlocks):
)
# wan22
# wan2.2: text2vid
## denoise
class Wan22CoreDenoiseStep(SequentialPipelineBlocks):
block_classes = [
WanTextInputStep,
WanSetTimestepsStep,
WanPrepareLatentsStep,
Wan22DenoiseStep,
]
block_names = ["input", "set_timesteps", "prepare_latents", "denoise"]
@property
def description(self):
return (
"denoise block that takes encoded conditions and runs the denoising process.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
+ " - `Wan22DenoiseStep` is used to denoise the latents in wan2.2\n"
)
# wan2.2: image2video
## denoise
class Wan22Image2VideoCoreDenoiseStep(SequentialPipelineBlocks):
block_classes = [
WanTextInputStep,
WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"]),
WanSetTimestepsStep,
WanPrepareLatentsStep,
WanPrepareFirstFrameLatentsStep,
Wan22Image2VideoDenoiseStep,
]
block_names = [
"input",
"additional_inputs",
"set_timesteps",
"prepare_latents",
"prepare_first_frame_latents",
"denoise",
]
@property
def description(self):
return (
"denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
+ " - `WanPrepareFirstFrameLatentsStep` is used to prepare the first frame latent conditions\n"
+ " - `Wan22Image2VideoDenoiseStep` is used to denoise the latents in wan2.2\n"
)
class Wan22AutoDenoiseStep(AutoPipelineBlocks):
block_classes = [
Wan22Image2VideoCoreDenoiseStep,
Wan22CoreDenoiseStep,
]
block_names = ["image2video", "text2video"]
block_trigger_inputs = ["first_frame_latents", None]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. "
"This is a auto pipeline block that works for text2video and image2video tasks."
" - `Wan22Image2VideoCoreDenoiseStep` (image2video) for image2video tasks."
" - `Wan22CoreDenoiseStep` (text2video) for text2vid tasks."
+ " - if `first_frame_latents` is provided, `Wan22Image2VideoCoreDenoiseStep` will be used.\n"
+ " - if `first_frame_latents` is not provided, `Wan22CoreDenoiseStep` will be used.\n"
)
class Wan22AutoBlocks(SequentialPipelineBlocks):
block_classes = [
WanTextEncoderStep,
WanAutoVaeImageEncoderStep,
Wan22AutoDenoiseStep,
WanImageVaeDecoderStep,
]
block_names = [
"text_encoder",
"vae_image_encoder",
"denoise",
"decode",
]
@property
def description(self):
return (
"Auto Modular pipeline for text-to-video using Wan2.2.\n"
+ "- for text-to-video generation, all you need to provide is `prompt`"
)
# presets for wan2.1 and wan2.2
# YiYi Notes: should we move these to doc?
# wan2.1
TEXT2VIDEO_BLOCKS = InsertableDict(
[
("text_encoder", WanTextEncoderStep),
("input", WanInputStep),
("input", WanTextInputStep),
("set_timesteps", WanSetTimestepsStep),
("prepare_latents", WanPrepareLatentsStep),
("denoise", WanDenoiseStep),
("decode", WanDecodeStep),
("decode", WanImageVaeDecoderStep),
]
)
IMAGE2VIDEO_BLOCKS = InsertableDict(
[
("image_resize", WanImageResizeStep),
("image_encoder", WanImage2VideoImageEncoderStep),
("vae_image_encoder", WanImage2VideoVaeImageEncoderStep),
("input", WanTextInputStep),
("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"])),
("set_timesteps", WanSetTimestepsStep),
("prepare_latents", WanPrepareLatentsStep),
("prepare_first_frame_latents", WanPrepareFirstFrameLatentsStep),
("denoise", WanImage2VideoDenoiseStep),
("decode", WanImageVaeDecoderStep),
]
)
FLF2V_BLOCKS = InsertableDict(
[
("image_resize", WanImageResizeStep),
("last_image_resize", WanImageCropResizeStep),
("image_encoder", WanFLF2VImageEncoderStep),
("vae_image_encoder", WanFLF2VVaeImageEncoderStep),
("input", WanTextInputStep),
("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_last_frame_latents"])),
("set_timesteps", WanSetTimestepsStep),
("prepare_latents", WanPrepareLatentsStep),
("prepare_first_last_frame_latents", WanPrepareFirstLastFrameLatentsStep),
("denoise", WanFLF2VDenoiseStep),
("decode", WanImageVaeDecoderStep),
]
)
AUTO_BLOCKS = InsertableDict(
[
("text_encoder", WanTextEncoderStep),
("before_denoise", WanAutoBeforeDenoiseStep),
("image_encoder", WanAutoImageEncoderStep),
("vae_image_encoder", WanAutoVaeImageEncoderStep),
("denoise", WanAutoDenoiseStep),
("decode", WanAutoDecodeStep),
("decode", WanImageVaeDecoderStep),
]
)
# wan2.2 presets
TEXT2VIDEO_BLOCKS_WAN22 = InsertableDict(
[
("text_encoder", WanTextEncoderStep),
("input", WanTextInputStep),
("set_timesteps", WanSetTimestepsStep),
("prepare_latents", WanPrepareLatentsStep),
("denoise", Wan22DenoiseStep),
("decode", WanImageVaeDecoderStep),
]
)
IMAGE2VIDEO_BLOCKS_WAN22 = InsertableDict(
[
("image_resize", WanImageResizeStep),
("vae_image_encoder", WanImage2VideoVaeImageEncoderStep),
("input", WanTextInputStep),
("set_timesteps", WanSetTimestepsStep),
("prepare_latents", WanPrepareLatentsStep),
("denoise", Wan22DenoiseStep),
("decode", WanImageVaeDecoderStep),
]
)
AUTO_BLOCKS_WAN22 = InsertableDict(
[
("text_encoder", WanTextEncoderStep),
("vae_image_encoder", WanAutoVaeImageEncoderStep),
("denoise", Wan22AutoDenoiseStep),
("decode", WanImageVaeDecoderStep),
]
)
# presets all blocks (wan and wan22)
ALL_BLOCKS = {
"text2video": TEXT2VIDEO_BLOCKS,
"auto": AUTO_BLOCKS,
"wan2.1": {
"text2video": TEXT2VIDEO_BLOCKS,
"image2video": IMAGE2VIDEO_BLOCKS,
"flf2v": FLF2V_BLOCKS,
"auto": AUTO_BLOCKS,
},
"wan2.2": {
"text2video": TEXT2VIDEO_BLOCKS_WAN22,
"image2video": IMAGE2VIDEO_BLOCKS_WAN22,
"auto": AUTO_BLOCKS_WAN22,
},
}
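# Minimal composition sketch for the presets above (a hedged example; checkpoints and weights are out
# of scope here):
#   from diffusers.modular_pipelines import SequentialPipelineBlocks
#   t2v_blocks = SequentialPipelineBlocks.from_blocks_dict(ALL_BLOCKS["wan2.1"]["text2video"])
#   i2v_blocks = SequentialPipelineBlocks.from_blocks_dict(ALL_BLOCKS["wan2.2"]["image2video"])
# The resulting blocks can then be turned into a runnable `WanModularPipeline` via `init_pipeline`
# with a matching modular checkpoint, just like the auto blocks above.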

View File

@@ -13,6 +13,8 @@
# limitations under the License.
from typing import Any, Dict, Optional
from ...loaders import WanLoraLoaderMixin
from ...pipelines.pipeline_utils import StableDiffusionMixin
from ...utils import logging
@@ -35,6 +37,13 @@ class WanModularPipeline(
default_blocks_name = "WanAutoBlocks"
# override the default_blocks_name in base class, which is just return self.default_blocks_name
def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
if config_dict is not None and "boundary_ratio" in config_dict and config_dict["boundary_ratio"] is not None:
return "Wan22AutoBlocks"
else:
return "WanAutoBlocks"
@property
def default_height(self):
return self.default_sample_height * self.vae_scale_factor_spatial
@@ -59,6 +68,13 @@ class WanModularPipeline(
def default_sample_num_frames(self):
return 21
@property
def patch_size_spatial(self):
patch_size_spatial = 2
if hasattr(self, "transformer") and self.transformer is not None:
patch_size_spatial = self.transformer.config.patch_size[1]
return patch_size_spatial
@property
def vae_scale_factor_spatial(self):
vae_scale_factor = 8
@@ -86,3 +102,19 @@ class WanModularPipeline(
if hasattr(self, "vae") and self.vae is not None:
num_channels_latents = self.vae.config.z_dim
return num_channels_latents
@property
def requires_unconditional_embeds(self):
requires_unconditional_embeds = False
if hasattr(self, "guider") and self.guider is not None:
requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1
return requires_unconditional_embeds
@property
def num_train_timesteps(self):
num_train_timesteps = 1000
if hasattr(self, "scheduler") and self.scheduler is not None:
num_train_timesteps = self.scheduler.config.num_train_timesteps
return num_train_timesteps

View File

@@ -117,6 +117,7 @@ from .stable_diffusion_xl import (
StableDiffusionXLInpaintPipeline,
StableDiffusionXLPipeline,
)
from .wan import WanImageToVideoPipeline, WanPipeline, WanVideoToVideoPipeline
from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline
@@ -214,6 +215,24 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
]
)
AUTO_TEXT2VIDEO_PIPELINES_MAPPING = OrderedDict(
[
("wan", WanPipeline),
]
)
AUTO_IMAGE2VIDEO_PIPELINES_MAPPING = OrderedDict(
[
("wan", WanImageToVideoPipeline),
]
)
AUTO_VIDEO2VIDEO_PIPELINES_MAPPING = OrderedDict(
[
("wan", WanVideoToVideoPipeline),
]
)
_AUTO_TEXT2IMAGE_DECODER_PIPELINES_MAPPING = OrderedDict(
[
("kandinsky", KandinskyPipeline),
@@ -247,6 +266,9 @@ SUPPORTED_TASKS_MAPPINGS = [
AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
AUTO_INPAINT_PIPELINES_MAPPING,
AUTO_TEXT2VIDEO_PIPELINES_MAPPING,
AUTO_IMAGE2VIDEO_PIPELINES_MAPPING,
AUTO_VIDEO2VIDEO_PIPELINES_MAPPING,
_AUTO_TEXT2IMAGE_DECODER_PIPELINES_MAPPING,
_AUTO_IMAGE2IMAGE_DECODER_PIPELINES_MAPPING,
_AUTO_INPAINT_DECODER_PIPELINES_MAPPING,

View File

@@ -182,6 +182,21 @@ class StableDiffusionXLModularPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class Wan22AutoBlocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class WanAutoBlocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]