diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index a6b90831f7..a5040bd283 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -407,6 +407,7 @@ else: "QwenImageModularPipeline", "StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline", + "Wan22AutoBlocks", "WanAutoBlocks", "WanModularPipeline", ] @@ -1090,6 +1091,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: QwenImageModularPipeline, StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline, + Wan22AutoBlocks, WanAutoBlocks, WanModularPipeline, ) diff --git a/src/diffusers/guiders/adaptive_projected_guidance.py b/src/diffusers/guiders/adaptive_projected_guidance.py index 492d10d2f1..8ec30d02d7 100644 --- a/src/diffusers/guiders/adaptive_projected_guidance.py +++ b/src/diffusers/guiders/adaptive_projected_guidance.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import torch @@ -88,6 +88,19 @@ class AdaptiveProjectedGuidance(BaseGuidance): data_batches.append(data_batch) return data_batches + def prepare_inputs_from_block_state( + self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]] + ) -> List["BlockState"]: + if self._step == 0: + if self.adaptive_projected_guidance_momentum is not None: + self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum) + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions): + data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction) + data_batches.append(data_batch) + return data_batches + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput: pred = None diff --git a/src/diffusers/guiders/adaptive_projected_guidance_mix.py b/src/diffusers/guiders/adaptive_projected_guidance_mix.py index 732741fc92..bdc97bcf62 100644 --- a/src/diffusers/guiders/adaptive_projected_guidance_mix.py +++ b/src/diffusers/guiders/adaptive_projected_guidance_mix.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import torch @@ -99,6 +99,19 @@ class AdaptiveProjectedMixGuidance(BaseGuidance): data_batches.append(data_batch) return data_batches + def prepare_inputs_from_block_state( + self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]] + ) -> List["BlockState"]: + if self._step == 0: + if self.adaptive_projected_guidance_momentum is not None: + self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum) + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions): + data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction) + data_batches.append(data_batch) + return data_batches + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput: pred = None diff --git a/src/diffusers/guiders/auto_guidance.py b/src/diffusers/guiders/auto_guidance.py index 4374f45aff..b7f62e2f4a 100644 --- a/src/diffusers/guiders/auto_guidance.py +++ b/src/diffusers/guiders/auto_guidance.py @@ -141,6 +141,16 @@ class AutoGuidance(BaseGuidance): data_batches.append(data_batch) return data_batches + def prepare_inputs_from_block_state( + self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]] + ) -> List["BlockState"]: + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions): + data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction) + data_batches.append(data_batch) + return data_batches + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput: pred = None diff --git a/src/diffusers/guiders/classifier_free_guidance.py b/src/diffusers/guiders/classifier_free_guidance.py index d475b30226..5e55d4d869 100644 --- a/src/diffusers/guiders/classifier_free_guidance.py +++ b/src/diffusers/guiders/classifier_free_guidance.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import torch @@ -99,6 +99,16 @@ class ClassifierFreeGuidance(BaseGuidance): data_batches.append(data_batch) return data_batches + def prepare_inputs_from_block_state( + self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]] + ) -> List["BlockState"]: + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions): + data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction) + data_batches.append(data_batch) + return data_batches + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput: pred = None diff --git a/src/diffusers/guiders/classifier_free_zero_star_guidance.py b/src/diffusers/guiders/classifier_free_zero_star_guidance.py index 1ea6bbb1c8..23b492e51b 100644 --- a/src/diffusers/guiders/classifier_free_zero_star_guidance.py +++ b/src/diffusers/guiders/classifier_free_zero_star_guidance.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import torch @@ -85,6 +85,16 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance): data_batches.append(data_batch) return data_batches + def prepare_inputs_from_block_state( + self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]] + ) -> List["BlockState"]: + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions): + data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction) + data_batches.append(data_batch) + return data_batches + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput: pred = None diff --git a/src/diffusers/guiders/frequency_decoupled_guidance.py b/src/diffusers/guiders/frequency_decoupled_guidance.py index cd542a43a4..4ec6e2d36d 100644 --- a/src/diffusers/guiders/frequency_decoupled_guidance.py +++ b/src/diffusers/guiders/frequency_decoupled_guidance.py @@ -226,6 +226,16 @@ class FrequencyDecoupledGuidance(BaseGuidance): data_batches.append(data_batch) return data_batches + def prepare_inputs_from_block_state( + self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]] + ) -> List["BlockState"]: + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions): + data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction) + data_batches.append(data_batch) + return data_batches + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput: pred = None diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py index 71e4becfcd..52cb0ce349 100644 --- a/src/diffusers/guiders/guider_utils.py +++ b/src/diffusers/guiders/guider_utils.py @@ -166,6 +166,11 @@ class BaseGuidance(ConfigMixin, PushToHubMixin): def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.") + def prepare_inputs_from_block_state( + self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]] + ) -> List["BlockState"]: + raise NotImplementedError("BaseGuidance::prepare_inputs_from_block_state must be implemented in subclasses.") + def __call__(self, data: List["BlockState"]) -> Any: if not all(hasattr(d, "noise_pred") for d in data): raise ValueError("Expected all data to have `noise_pred` attribute.") @@ -234,6 +239,51 @@ class BaseGuidance(ConfigMixin, PushToHubMixin): data_batch[cls._identifier_key] = identifier return BlockState(**data_batch) + @classmethod + def _prepare_batch_from_block_state( + cls, + input_fields: Dict[str, Union[str, Tuple[str, str]]], + data: "BlockState", + tuple_index: int, + identifier: str, + ) -> "BlockState": + """ + Prepares a batch of data for the guidance technique. This method is used in the `prepare_inputs` method of the + `BaseGuidance` class. It prepares the batch based on the provided tuple index. + + Args: + input_fields (`Dict[str, Union[str, Tuple[str, str]]]`): + A dictionary where the keys are the names of the fields that will be used to store the data once it is + prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used + to look up the required data provided for preparation. If a string is provided, it will be used as the + conditional data (or unconditional if used with a guidance method that requires it). If a tuple of + length 2 is provided, the first element must be the conditional data identifier and the second element + must be the unconditional data identifier or None. + data (`BlockState`): + The input data to be prepared. + tuple_index (`int`): + The index to use when accessing input fields that are tuples. + + Returns: + `BlockState`: The prepared batch of data. + """ + from ..modular_pipelines.modular_pipeline import BlockState + + data_batch = {} + for key, value in input_fields.items(): + try: + if isinstance(value, str): + data_batch[key] = getattr(data, value) + elif isinstance(value, tuple): + data_batch[key] = getattr(data, value[tuple_index]) + else: + # We've already checked that value is a string or a tuple of strings with length 2 + pass + except AttributeError: + logger.debug(f"`data` does not have attribute(s) {value}, skipping.") + data_batch[cls._identifier_key] = identifier + return BlockState(**data_batch) + @classmethod @validate_hf_hub_args def from_pretrained( diff --git a/src/diffusers/guiders/perturbed_attention_guidance.py b/src/diffusers/guiders/perturbed_attention_guidance.py index 29341736e8..f233e90ca4 100644 --- a/src/diffusers/guiders/perturbed_attention_guidance.py +++ b/src/diffusers/guiders/perturbed_attention_guidance.py @@ -187,6 +187,26 @@ class PerturbedAttentionGuidance(BaseGuidance): data_batches.append(data_batch) return data_batches + def prepare_inputs_from_block_state( + self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]] + ) -> List["BlockState"]: + if self.num_conditions == 1: + tuple_indices = [0] + input_predictions = ["pred_cond"] + elif self.num_conditions == 2: + tuple_indices = [0, 1] + input_predictions = ( + ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"] + ) + else: + tuple_indices = [0, 1, 0] + input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"] + data_batches = [] + for tuple_idx, input_prediction in zip(tuple_indices, input_predictions): + data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction) + data_batches.append(data_batch) + return data_batches + # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.forward def forward( self, diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py index fa5b93b680..e6109300d9 100644 --- a/src/diffusers/guiders/skip_layer_guidance.py +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -183,6 +183,26 @@ class SkipLayerGuidance(BaseGuidance): data_batches.append(data_batch) return data_batches + def prepare_inputs_from_block_state( + self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]] + ) -> List["BlockState"]: + if self.num_conditions == 1: + tuple_indices = [0] + input_predictions = ["pred_cond"] + elif self.num_conditions == 2: + tuple_indices = [0, 1] + input_predictions = ( + ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"] + ) + else: + tuple_indices = [0, 1, 0] + input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"] + data_batches = [] + for tuple_idx, input_prediction in zip(tuple_indices, input_predictions): + data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction) + data_batches.append(data_batch) + return data_batches + def forward( self, pred_cond: torch.Tensor, diff --git a/src/diffusers/guiders/smoothed_energy_guidance.py b/src/diffusers/guiders/smoothed_energy_guidance.py index 7446b33f12..6c3906e820 100644 --- a/src/diffusers/guiders/smoothed_energy_guidance.py +++ b/src/diffusers/guiders/smoothed_energy_guidance.py @@ -172,6 +172,26 @@ class SmoothedEnergyGuidance(BaseGuidance): data_batches.append(data_batch) return data_batches + def prepare_inputs_from_block_state( + self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]] + ) -> List["BlockState"]: + if self.num_conditions == 1: + tuple_indices = [0] + input_predictions = ["pred_cond"] + elif self.num_conditions == 2: + tuple_indices = [0, 1] + input_predictions = ( + ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_seg"] + ) + else: + tuple_indices = [0, 1, 0] + input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"] + data_batches = [] + for tuple_idx, input_prediction in zip(tuple_indices, input_predictions): + data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction) + data_batches.append(data_batch) + return data_batches + def forward( self, pred_cond: torch.Tensor, diff --git a/src/diffusers/guiders/tangential_classifier_free_guidance.py b/src/diffusers/guiders/tangential_classifier_free_guidance.py index cfa3c4a616..76899c6e84 100644 --- a/src/diffusers/guiders/tangential_classifier_free_guidance.py +++ b/src/diffusers/guiders/tangential_classifier_free_guidance.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import torch @@ -74,6 +74,16 @@ class TangentialClassifierFreeGuidance(BaseGuidance): data_batches.append(data_batch) return data_batches + def prepare_inputs_from_block_state( + self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]] + ) -> List["BlockState"]: + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] + data_batches = [] + for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions): + data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction) + data_batches.append(data_batch) + return data_batches + def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput: pred = None diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index 86ed735134..252b9f33df 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -45,7 +45,7 @@ else: "InsertableDict", ] _import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"] - _import_structure["wan"] = ["WanAutoBlocks", "WanModularPipeline"] + _import_structure["wan"] = ["WanAutoBlocks", "Wan22AutoBlocks", "WanModularPipeline"] _import_structure["flux"] = [ "FluxAutoBlocks", "FluxModularPipeline", @@ -90,7 +90,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: QwenImageModularPipeline, ) from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline - from .wan import WanAutoBlocks, WanModularPipeline + from .wan import Wan22AutoBlocks, WanAutoBlocks, WanModularPipeline else: import sys diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 307698245e..151adbbc03 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -1441,6 +1441,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin): pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, components_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, + modular_config_dict: Optional[Dict[str, Any]] = None, + config_dict: Optional[Dict[str, Any]] = None, **kwargs, ): """ @@ -1492,23 +1494,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin): - The pipeline's config dict is also used to store the pipeline blocks's class name, which will be saved as `_blocks_class_name` in the config dict """ - if blocks is None: - blocks_class_name = self.default_blocks_name - if blocks_class_name is not None: - diffusers_module = importlib.import_module("diffusers") - blocks_class = getattr(diffusers_module, blocks_class_name) - blocks = blocks_class() - else: - logger.warning(f"`blocks` is `None`, no default blocks class found for {self.__class__.__name__}") - self.blocks = blocks - self._components_manager = components_manager - self._collection = collection - self._component_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_components} - self._config_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_configs} - - # update component_specs and config_specs from modular_repo - if pretrained_model_name_or_path is not None: + if modular_config_dict is None and config_dict is None and pretrained_model_name_or_path is not None: cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) proxies = kwargs.pop("proxies", None) @@ -1524,52 +1511,59 @@ class ModularPipeline(ConfigMixin, PushToHubMixin): "local_files_only": local_files_only, "revision": revision, } - # try to load modular_model_index.json - try: - config_dict = self.load_config(pretrained_model_name_or_path, **load_config_kwargs) - except EnvironmentError as e: - logger.debug(f"modular_model_index.json not found: {e}") - config_dict = None - # update component_specs and config_specs based on modular_model_index.json - if config_dict is not None: - for name, value in config_dict.items(): - # all the components in modular_model_index.json are from_pretrained components - if name in self._component_specs and isinstance(value, (tuple, list)) and len(value) == 3: - library, class_name, component_spec_dict = value - component_spec = self._dict_to_component_spec(name, component_spec_dict) - component_spec.default_creation_method = "from_pretrained" - self._component_specs[name] = component_spec + modular_config_dict, config_dict = self._load_pipeline_config( + pretrained_model_name_or_path, **load_config_kwargs + ) - elif name in self._config_specs: - self._config_specs[name].default = value - - # if modular_model_index.json is not found, try to load model_index.json + if blocks is None: + if modular_config_dict is not None: + blocks_class_name = modular_config_dict.get("_blocks_class_name") + elif config_dict is not None: + blocks_class_name = self.get_default_blocks_name(config_dict) else: - logger.debug(" loading config from model_index.json") - try: - from diffusers import DiffusionPipeline + blocks_class_name = None + if blocks_class_name is not None: + diffusers_module = importlib.import_module("diffusers") + blocks_class = getattr(diffusers_module, blocks_class_name) + blocks = blocks_class() + else: + logger.warning(f"`blocks` is `None`, no default blocks class found for {self.__class__.__name__}") - config_dict = DiffusionPipeline.load_config(pretrained_model_name_or_path, **load_config_kwargs) - except EnvironmentError as e: - logger.debug(f" model_index.json not found in the repo: {e}") - config_dict = None + self.blocks = blocks + self._components_manager = components_manager + self._collection = collection + self._component_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_components} + self._config_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_configs} - # update component_specs and config_specs based on model_index.json - if config_dict is not None: - for name, value in config_dict.items(): - if name in self._component_specs and isinstance(value, (tuple, list)) and len(value) == 2: - library, class_name = value - component_spec_dict = { - "repo": pretrained_model_name_or_path, - "subfolder": name, - "type_hint": (library, class_name), - } - component_spec = self._dict_to_component_spec(name, component_spec_dict) - component_spec.default_creation_method = "from_pretrained" - self._component_specs[name] = component_spec - elif name in self._config_specs: - self._config_specs[name].default = value + # update component_specs and config_specs based on modular_model_index.json + if modular_config_dict is not None: + for name, value in modular_config_dict.items(): + # all the components in modular_model_index.json are from_pretrained components + if name in self._component_specs and isinstance(value, (tuple, list)) and len(value) == 3: + library, class_name, component_spec_dict = value + component_spec = self._dict_to_component_spec(name, component_spec_dict) + component_spec.default_creation_method = "from_pretrained" + self._component_specs[name] = component_spec + + elif name in self._config_specs: + self._config_specs[name].default = value + + # if `modular_config_dict` is None (i.e. `modular_model_index.json` is not found), update based on `config_dict` (i.e. `model_index.json`) + elif config_dict is not None: + for name, value in config_dict.items(): + if name in self._component_specs and isinstance(value, (tuple, list)) and len(value) == 2: + library, class_name = value + component_spec_dict = { + "repo": pretrained_model_name_or_path, + "subfolder": name, + "type_hint": (library, class_name), + } + component_spec = self._dict_to_component_spec(name, component_spec_dict) + component_spec.default_creation_method = "from_pretrained" + self._component_specs[name] = component_spec + elif name in self._config_specs: + self._config_specs[name].default = value if len(kwargs) > 0: logger.warning(f"Unexpected input '{kwargs.keys()}' provided. This input will be ignored.") @@ -1601,6 +1595,35 @@ class ModularPipeline(ConfigMixin, PushToHubMixin): params[input_param.name] = input_param.default return params + def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]: + return self.default_blocks_name + + @classmethod + def _load_pipeline_config( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + **load_config_kwargs, + ): + try: + # try to load modular_model_index.json + modular_config_dict = cls.load_config(pretrained_model_name_or_path, **load_config_kwargs) + return modular_config_dict, None + + except EnvironmentError as e: + logger.debug(f" modular_model_index.json not found in the repo: {e}") + + try: + logger.debug(" try to load model_index.json") + from diffusers import DiffusionPipeline + + config_dict = DiffusionPipeline.load_config(pretrained_model_name_or_path, **load_config_kwargs) + return None, config_dict + + except EnvironmentError as e: + logger.debug(f" model_index.json not found in the repo: {e}") + + return None, None + @classmethod @validate_hf_hub_args def from_pretrained( @@ -1655,42 +1678,33 @@ class ModularPipeline(ConfigMixin, PushToHubMixin): "revision": revision, } - try: - # try to load modular_model_index.json - config_dict = cls.load_config(pretrained_model_name_or_path, **load_config_kwargs) - except EnvironmentError as e: - logger.debug(f" modular_model_index.json not found in the repo: {e}") - config_dict = None + modular_config_dict, config_dict = cls._load_pipeline_config( + pretrained_model_name_or_path, **load_config_kwargs + ) - if config_dict is not None: - pipeline_class = _get_pipeline_class(cls, config=config_dict) + if modular_config_dict is not None: + pipeline_class = _get_pipeline_class(cls, config=modular_config_dict) + elif config_dict is not None: + from diffusers.pipelines.auto_pipeline import _get_model + + logger.debug(" try to determine the modular pipeline class from model_index.json") + standard_pipeline_class = _get_pipeline_class(cls, config=config_dict) + model_name = _get_model(standard_pipeline_class.__name__) + pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(model_name, ModularPipeline.__name__) + diffusers_module = importlib.import_module("diffusers") + pipeline_class = getattr(diffusers_module, pipeline_class_name) else: - try: - logger.debug(" try to load model_index.json") - from diffusers import DiffusionPipeline - from diffusers.pipelines.auto_pipeline import _get_model - - config_dict = DiffusionPipeline.load_config(pretrained_model_name_or_path, **load_config_kwargs) - except EnvironmentError as e: - logger.debug(f" model_index.json not found in the repo: {e}") - - if config_dict is not None: - logger.debug(" try to determine the modular pipeline class from model_index.json") - standard_pipeline_class = _get_pipeline_class(cls, config=config_dict) - model_name = _get_model(standard_pipeline_class.__name__) - pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(model_name, ModularPipeline.__name__) - diffusers_module = importlib.import_module("diffusers") - pipeline_class = getattr(diffusers_module, pipeline_class_name) - else: - # there is no config for modular pipeline, assuming that the pipeline block does not need any from_pretrained components - pipeline_class = cls - pretrained_model_name_or_path = None + # there is no config for modular pipeline, assuming that the pipeline block does not need any from_pretrained components + pipeline_class = cls + pretrained_model_name_or_path = None pipeline = pipeline_class( blocks=blocks, pretrained_model_name_or_path=pretrained_model_name_or_path, components_manager=components_manager, collection=collection, + modular_config_dict=modular_config_dict, + config_dict=config_dict, **kwargs, ) return pipeline @@ -2134,7 +2148,9 @@ class ModularPipeline(ConfigMixin, PushToHubMixin): logger.warning( f"\nFailed to create component {name}:\n" f"- Component spec: {spec}\n" - f"- load() called with kwargs: {component_load_kwargs}\n\n" + f"- load() called with kwargs: {component_load_kwargs}\n" + "If this component is not required for your workflow you can safely ignore this message.\n\n" + "Traceback:\n" f"{traceback.format_exc()}" ) diff --git a/src/diffusers/modular_pipelines/wan/__init__.py b/src/diffusers/modular_pipelines/wan/__init__.py index 7b548e003c..73f67c9afe 100644 --- a/src/diffusers/modular_pipelines/wan/__init__.py +++ b/src/diffusers/modular_pipelines/wan/__init__.py @@ -21,16 +21,14 @@ except OptionalDependencyNotAvailable: _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: + _import_structure["decoders"] = ["WanImageVaeDecoderStep"] _import_structure["encoders"] = ["WanTextEncoderStep"] _import_structure["modular_blocks"] = [ "ALL_BLOCKS", - "AUTO_BLOCKS", - "TEXT2VIDEO_BLOCKS", - "WanAutoBeforeDenoiseStep", + "Wan22AutoBlocks", "WanAutoBlocks", - "WanAutoBlocks", - "WanAutoDecodeStep", - "WanAutoDenoiseStep", + "WanAutoImageEncoderStep", + "WanAutoVaeImageEncoderStep", ] _import_structure["modular_pipeline"] = ["WanModularPipeline"] @@ -41,15 +39,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: + from .decoders import WanImageVaeDecoderStep from .encoders import WanTextEncoderStep from .modular_blocks import ( ALL_BLOCKS, - AUTO_BLOCKS, - TEXT2VIDEO_BLOCKS, - WanAutoBeforeDenoiseStep, + Wan22AutoBlocks, WanAutoBlocks, - WanAutoDecodeStep, - WanAutoDenoiseStep, + WanAutoImageEncoderStep, + WanAutoVaeImageEncoderStep, ) from .modular_pipeline import WanModularPipeline else: diff --git a/src/diffusers/modular_pipelines/wan/before_denoise.py b/src/diffusers/modular_pipelines/wan/before_denoise.py index d48f678edd..e2f8d3e7d8 100644 --- a/src/diffusers/modular_pipelines/wan/before_denoise.py +++ b/src/diffusers/modular_pipelines/wan/before_denoise.py @@ -13,10 +13,11 @@ # limitations under the License. import inspect -from typing import List, Optional, Union +from typing import List, Optional, Tuple, Union import torch +from ...models import WanTransformer3DModel from ...schedulers import UniPCMultistepScheduler from ...utils import logging from ...utils.torch_utils import randn_tensor @@ -34,6 +35,97 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name # configuration of guider is. +def repeat_tensor_to_batch_size( + input_name: str, + input_tensor: torch.Tensor, + batch_size: int, + num_videos_per_prompt: int = 1, +) -> torch.Tensor: + """Repeat tensor elements to match the final batch size. + + This function expands a tensor's batch dimension to match the final batch size (batch_size * num_videos_per_prompt) + by repeating each element along dimension 0. + + The input tensor must have batch size 1 or batch_size. The function will: + - If batch size is 1: repeat each element (batch_size * num_videos_per_prompt) times + - If batch size equals batch_size: repeat each element num_videos_per_prompt times + + Args: + input_name (str): Name of the input tensor (used for error messages) + input_tensor (torch.Tensor): The tensor to repeat. Must have batch size 1 or batch_size. + batch_size (int): The base batch size (number of prompts) + num_videos_per_prompt (int, optional): Number of videos to generate per prompt. Defaults to 1. + + Returns: + torch.Tensor: The repeated tensor with final batch size (batch_size * num_videos_per_prompt) + + Raises: + ValueError: If input_tensor is not a torch.Tensor or has invalid batch size + + Examples: + tensor = torch.tensor([[1, 2, 3]]) # shape: [1, 3] repeated = repeat_tensor_to_batch_size("image", tensor, + batch_size=2, num_videos_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]) - shape: + [4, 3] + + tensor = torch.tensor([[1, 2, 3], [4, 5, 6]]) # shape: [2, 3] repeated = repeat_tensor_to_batch_size("image", + tensor, batch_size=2, num_videos_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]]) + - shape: [4, 3] + """ + # make sure input is a tensor + if not isinstance(input_tensor, torch.Tensor): + raise ValueError(f"`{input_name}` must be a tensor") + + # make sure input tensor e.g. image_latents has batch size 1 or batch_size same as prompts + if input_tensor.shape[0] == 1: + repeat_by = batch_size * num_videos_per_prompt + elif input_tensor.shape[0] == batch_size: + repeat_by = num_videos_per_prompt + else: + raise ValueError( + f"`{input_name}` must have have batch size 1 or {batch_size}, but got {input_tensor.shape[0]}" + ) + + # expand the tensor to match the batch_size * num_videos_per_prompt + input_tensor = input_tensor.repeat_interleave(repeat_by, dim=0) + + return input_tensor + + +def calculate_dimension_from_latents( + latents: torch.Tensor, vae_scale_factor_temporal: int, vae_scale_factor_spatial: int +) -> Tuple[int, int]: + """Calculate image dimensions from latent tensor dimensions. + + This function converts latent temporal and spatial dimensions to image temporal and spatial dimensions by + multiplying the latent num_frames/height/width by the VAE scale factor. + + Args: + latents (torch.Tensor): The latent tensor. Must have 4 or 5 dimensions. + Expected shapes: [batch, channels, height, width] or [batch, channels, frames, height, width] + vae_scale_factor_temporal (int): The scale factor used by the VAE to compress temporal dimension. + Typically 4 for most VAEs (video is 4x larger than latents in temporal dimension) + vae_scale_factor_spatial (int): The scale factor used by the VAE to compress spatial dimension. + Typically 8 for most VAEs (image is 8x larger than latents in each dimension) + + Returns: + Tuple[int, int]: The calculated image dimensions as (height, width) + + Raises: + ValueError: If latents tensor doesn't have 4 or 5 dimensions + + """ + if latents.ndim != 5: + raise ValueError(f"latents must have 5 dimensions, but got {latents.ndim}") + + _, _, num_latent_frames, latent_height, latent_width = latents.shape + + num_frames = (num_latent_frames - 1) * vae_scale_factor_temporal + 1 + height = latent_height * vae_scale_factor_spatial + width = latent_width * vae_scale_factor_spatial + + return num_frames, height, width + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, @@ -94,7 +186,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class WanInputStep(ModularPipelineBlocks): +class WanTextInputStep(ModularPipelineBlocks): model_name = "wan" @property @@ -109,14 +201,15 @@ class WanInputStep(ModularPipelineBlocks): ) @property - def inputs(self) -> List[InputParam]: + def expected_components(self) -> List[ComponentSpec]: return [ - InputParam("num_videos_per_prompt", default=1), + ComponentSpec("transformer", WanTransformer3DModel), ] @property - def intermediate_inputs(self) -> List[str]: + def inputs(self) -> List[InputParam]: return [ + InputParam("num_videos_per_prompt", default=1), InputParam( "prompt_embeds", required=True, @@ -141,19 +234,7 @@ class WanInputStep(ModularPipelineBlocks): OutputParam( "dtype", type_hint=torch.dtype, - description="Data type of model tensor inputs (determined by `prompt_embeds`)", - ), - OutputParam( - "prompt_embeds", - type_hint=torch.Tensor, - kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields - description="text embeddings used to guide the image generation", - ), - OutputParam( - "negative_prompt_embeds", - type_hint=torch.Tensor, - kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields - description="negative text embeddings used to guide the image generation", + description="Data type of model tensor inputs (determined by `transformer.dtype`)", ), ] @@ -194,6 +275,140 @@ class WanInputStep(ModularPipelineBlocks): return components, state +class WanAdditionalInputsStep(ModularPipelineBlocks): + model_name = "wan" + + def __init__( + self, + image_latent_inputs: List[str] = ["first_frame_latents"], + additional_batch_inputs: List[str] = [], + ): + """Initialize a configurable step that standardizes the inputs for the denoising step. It:\n" + + This step handles multiple common tasks to prepare inputs for the denoising step: + 1. For encoded image latents, use it update height/width if None, and expands batch size + 2. For additional_batch_inputs: Only expands batch dimensions to match final batch size + + This is a dynamic block that allows you to configure which inputs to process. + + Args: + image_latent_inputs (List[str], optional): Names of image latent tensors to process. + In additional to adjust batch size of these inputs, they will be used to determine height/width. Can be + a single string or list of strings. Defaults to ["first_frame_latents"]. + additional_batch_inputs (List[str], optional): + Names of additional conditional input tensors to expand batch size. These tensors will only have their + batch dimensions adjusted to match the final batch size. Can be a single string or list of strings. + Defaults to []. + + Examples: + # Configure to process first_frame_latents (default behavior) WanAdditionalInputsStep() + + # Configure to process multiple image latent inputs + WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents", "last_frame_latents"]) + + # Configure to process image latents and additional batch inputs WanAdditionalInputsStep( + image_latent_inputs=["first_frame_latents"], additional_batch_inputs=["image_embeds"] + ) + """ + if not isinstance(image_latent_inputs, list): + image_latent_inputs = [image_latent_inputs] + if not isinstance(additional_batch_inputs, list): + additional_batch_inputs = [additional_batch_inputs] + + self._image_latent_inputs = image_latent_inputs + self._additional_batch_inputs = additional_batch_inputs + super().__init__() + + @property + def description(self) -> str: + # Functionality section + summary_section = ( + "Input processing step that:\n" + " 1. For image latent inputs: Updates height/width if None, and expands batch size\n" + " 2. For additional batch inputs: Expands batch dimensions to match final batch size" + ) + + # Inputs info + inputs_info = "" + if self._image_latent_inputs or self._additional_batch_inputs: + inputs_info = "\n\nConfigured inputs:" + if self._image_latent_inputs: + inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}" + if self._additional_batch_inputs: + inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}" + + # Placement guidance + placement_section = "\n\nThis block should be placed after the encoder steps and the text input step." + + return summary_section + inputs_info + placement_section + + @property + def inputs(self) -> List[InputParam]: + inputs = [ + InputParam(name="num_videos_per_prompt", default=1), + InputParam(name="batch_size", required=True), + InputParam(name="height"), + InputParam(name="width"), + InputParam(name="num_frames"), + ] + + # Add image latent inputs + for image_latent_input_name in self._image_latent_inputs: + inputs.append(InputParam(name=image_latent_input_name)) + + # Add additional batch inputs + for input_name in self._additional_batch_inputs: + inputs.append(InputParam(name=input_name)) + + return inputs + + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # Process image latent inputs (height/width calculation, patchify, and batch expansion) + for image_latent_input_name in self._image_latent_inputs: + image_latent_tensor = getattr(block_state, image_latent_input_name) + if image_latent_tensor is None: + continue + + # 1. Calculate num_frames, height/width from latents + num_frames, height, width = calculate_dimension_from_latents( + image_latent_tensor, components.vae_scale_factor_temporal, components.vae_scale_factor_spatial + ) + block_state.num_frames = block_state.num_frames or num_frames + block_state.height = block_state.height or height + block_state.width = block_state.width or width + + # 3. Expand batch size + image_latent_tensor = repeat_tensor_to_batch_size( + input_name=image_latent_input_name, + input_tensor=image_latent_tensor, + num_videos_per_prompt=block_state.num_videos_per_prompt, + batch_size=block_state.batch_size, + ) + + setattr(block_state, image_latent_input_name, image_latent_tensor) + + # Process additional batch inputs (only batch expansion) + for input_name in self._additional_batch_inputs: + input_tensor = getattr(block_state, input_name) + if input_tensor is None: + continue + + # Only expand batch size + input_tensor = repeat_tensor_to_batch_size( + input_name=input_name, + input_tensor=input_tensor, + num_videos_per_prompt=block_state.num_videos_per_prompt, + batch_size=block_state.batch_size, + ) + + setattr(block_state, input_name, input_tensor) + + self.set_block_state(state, block_state) + return components, state + + class WanSetTimestepsStep(ModularPipelineBlocks): model_name = "wan" @@ -215,26 +430,15 @@ class WanSetTimestepsStep(ModularPipelineBlocks): InputParam("sigmas"), ] - @property - def intermediate_outputs(self) -> List[OutputParam]: - return [ - OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), - OutputParam( - "num_inference_steps", - type_hint=int, - description="The number of denoising steps to perform at inference time", - ), - ] - @torch.no_grad() def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) - block_state.device = components._execution_device + device = components._execution_device block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( components.scheduler, block_state.num_inference_steps, - block_state.device, + device, block_state.timesteps, block_state.sigmas, ) @@ -246,10 +450,6 @@ class WanSetTimestepsStep(ModularPipelineBlocks): class WanPrepareLatentsStep(ModularPipelineBlocks): model_name = "wan" - @property - def expected_components(self) -> List[ComponentSpec]: - return [] - @property def description(self) -> str: return "Prepare latents step that prepares the latents for the text-to-video generation process" @@ -262,11 +462,6 @@ class WanPrepareLatentsStep(ModularPipelineBlocks): InputParam("num_frames", type_hint=int), InputParam("latents", type_hint=Optional[torch.Tensor]), InputParam("num_videos_per_prompt", type_hint=int, default=1), - ] - - @property - def intermediate_inputs(self) -> List[InputParam]: - return [ InputParam("generator"), InputParam( "batch_size", @@ -337,29 +532,106 @@ class WanPrepareLatentsStep(ModularPipelineBlocks): @torch.no_grad() def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + device = components._execution_device + dtype = torch.float32 # Wan latents should be torch.float32 for best quality block_state.height = block_state.height or components.default_height block_state.width = block_state.width or components.default_width block_state.num_frames = block_state.num_frames or components.default_num_frames - block_state.device = components._execution_device - block_state.dtype = torch.float32 # Wan latents should be torch.float32 for best quality - block_state.num_channels_latents = components.num_channels_latents - - self.check_inputs(components, block_state) block_state.latents = self.prepare_latents( components, - block_state.batch_size * block_state.num_videos_per_prompt, - block_state.num_channels_latents, - block_state.height, - block_state.width, - block_state.num_frames, - block_state.dtype, - block_state.device, - block_state.generator, - block_state.latents, + batch_size=block_state.batch_size * block_state.num_videos_per_prompt, + num_channels_latents=components.num_channels_latents, + height=block_state.height, + width=block_state.width, + num_frames=block_state.num_frames, + dtype=dtype, + device=device, + generator=block_state.generator, + latents=block_state.latents, ) self.set_block_state(state, block_state) return components, state + + +class WanPrepareFirstFrameLatentsStep(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return "step that prepares the masked first frame latents and add it to the latent condition" + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("first_frame_latents", type_hint=Optional[torch.Tensor]), + InputParam("num_frames", type_hint=int), + ] + + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + batch_size, _, _, latent_height, latent_width = block_state.first_frame_latents.shape + + mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width) + mask_lat_size[:, :, list(range(1, block_state.num_frames))] = 0 + + first_frame_mask = mask_lat_size[:, :, 0:1] + first_frame_mask = torch.repeat_interleave( + first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal + ) + mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) + mask_lat_size = mask_lat_size.view( + batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width + ) + mask_lat_size = mask_lat_size.transpose(1, 2) + mask_lat_size = mask_lat_size.to(block_state.first_frame_latents.device) + block_state.first_frame_latents = torch.concat([mask_lat_size, block_state.first_frame_latents], dim=1) + + self.set_block_state(state, block_state) + return components, state + + +class WanPrepareFirstLastFrameLatentsStep(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return "step that prepares the masked latents with first and last frames and add it to the latent condition" + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("first_last_frame_latents", type_hint=Optional[torch.Tensor]), + InputParam("num_frames", type_hint=int), + ] + + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + batch_size, _, _, latent_height, latent_width = block_state.first_last_frame_latents.shape + + mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width) + mask_lat_size[:, :, list(range(1, block_state.num_frames - 1))] = 0 + + first_frame_mask = mask_lat_size[:, :, 0:1] + first_frame_mask = torch.repeat_interleave( + first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal + ) + mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) + mask_lat_size = mask_lat_size.view( + batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width + ) + mask_lat_size = mask_lat_size.transpose(1, 2) + mask_lat_size = mask_lat_size.to(block_state.first_last_frame_latents.device) + block_state.first_last_frame_latents = torch.concat( + [mask_lat_size, block_state.first_last_frame_latents], dim=1 + ) + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/wan/decoders.py b/src/diffusers/modular_pipelines/wan/decoders.py index 8c751172d8..7cec318c17 100644 --- a/src/diffusers/modular_pipelines/wan/decoders.py +++ b/src/diffusers/modular_pipelines/wan/decoders.py @@ -29,7 +29,7 @@ from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class WanDecodeStep(ModularPipelineBlocks): +class WanImageVaeDecoderStep(ModularPipelineBlocks): model_name = "wan" @property @@ -50,12 +50,6 @@ class WanDecodeStep(ModularPipelineBlocks): @property def inputs(self) -> List[Tuple[str, Any]]: - return [ - InputParam("output_type", default="pil"), - ] - - @property - def intermediate_inputs(self) -> List[str]: return [ InputParam( "latents", @@ -80,25 +74,20 @@ class WanDecodeStep(ModularPipelineBlocks): block_state = self.get_block_state(state) vae_dtype = components.vae.dtype - if not block_state.output_type == "latent": - latents = block_state.latents - latents_mean = ( - torch.tensor(components.vae.config.latents_mean) - .view(1, components.vae.config.z_dim, 1, 1, 1) - .to(latents.device, latents.dtype) - ) - latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view( - 1, components.vae.config.z_dim, 1, 1, 1 - ).to(latents.device, latents.dtype) - latents = latents / latents_std + latents_mean - latents = latents.to(vae_dtype) - block_state.videos = components.vae.decode(latents, return_dict=False)[0] - else: - block_state.videos = block_state.latents - - block_state.videos = components.video_processor.postprocess_video( - block_state.videos, output_type=block_state.output_type + latents = block_state.latents + latents_mean = ( + torch.tensor(components.vae.config.latents_mean) + .view(1, components.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) ) + latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view( + 1, components.vae.config.z_dim, 1, 1, 1 + ).to(latents.device, latents.dtype) + latents = latents / latents_std + latents_mean + latents = latents.to(vae_dtype) + block_state.videos = components.vae.decode(latents, return_dict=False)[0] + + block_state.videos = components.video_processor.postprocess_video(block_state.videos, output_type="np") self.set_block_state(state, block_state) diff --git a/src/diffusers/modular_pipelines/wan/denoise.py b/src/diffusers/modular_pipelines/wan/denoise.py index 4f3ca80acc..2da36f52da 100644 --- a/src/diffusers/modular_pipelines/wan/denoise.py +++ b/src/diffusers/modular_pipelines/wan/denoise.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Tuple +from typing import Any, Dict, List, Tuple import torch @@ -27,16 +27,156 @@ from ..modular_pipeline import ( ModularPipelineBlocks, PipelineState, ) -from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam from .modular_pipeline import WanModularPipeline logger = logging.get_logger(__name__) # pylint: disable=invalid-name +class WanLoopBeforeDenoiser(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return ( + "step within the denoising loop that prepares the latent input for the denoiser. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `WanDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + InputParam( + "dtype", + required=True, + type_hint=torch.dtype, + description="The dtype of the model inputs. Can be generated in input step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + block_state.latent_model_input = block_state.latents.to(block_state.dtype) + return components, block_state + + +class WanImage2VideoLoopBeforeDenoiser(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return ( + "step within the denoising loop that prepares the latent input for the denoiser. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `WanDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + InputParam( + "first_frame_latents", + required=True, + type_hint=torch.Tensor, + description="The first frame latents to use for the denoising process. Can be generated in prepare_first_frame_latents step.", + ), + InputParam( + "dtype", + required=True, + type_hint=torch.dtype, + description="The dtype of the model inputs. Can be generated in input step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + block_state.latent_model_input = torch.cat([block_state.latents, block_state.first_frame_latents], dim=1).to( + block_state.dtype + ) + return components, block_state + + +class WanFLF2VLoopBeforeDenoiser(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return ( + "step within the denoising loop that prepares the latent input for the denoiser. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `WanDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + InputParam( + "first_last_frame_latents", + required=True, + type_hint=torch.Tensor, + description="The first and last frame latents to use for the denoising process. Can be generated in prepare_first_last_frame_latents step.", + ), + InputParam( + "dtype", + required=True, + type_hint=torch.dtype, + description="The dtype of the model inputs. Can be generated in input step.", + ), + ] + + @torch.no_grad() + def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + block_state.latent_model_input = torch.cat( + [block_state.latents, block_state.first_last_frame_latents], dim=1 + ).to(block_state.dtype) + return components, block_state + + class WanLoopDenoiser(ModularPipelineBlocks): model_name = "wan" + def __init__( + self, + guider_input_fields: Dict[str, Any] = {"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds")}, + ): + """Initialize a denoiser block that calls the denoiser model. This block is used in Wan2.1. + + Args: + guider_input_fields: A dictionary that maps each argument expected by the denoiser model + (for example, "encoder_hidden_states") to data stored on 'block_state'. The value can be either: + + - A tuple of strings. For instance, {"encoder_hidden_states": ("prompt_embeds", + "negative_prompt_embeds")} tells the guider to read `block_state.prompt_embeds` and + `block_state.negative_prompt_embeds` and pass them as the conditional and unconditional batches of + 'encoder_hidden_states'. + - A string. For example, {"encoder_hidden_image": "image_embeds"} makes the guider forward + `block_state.image_embeds` for both conditional and unconditional batches. + """ + if not isinstance(guider_input_fields, dict): + raise ValueError(f"guider_input_fields must be a dictionary but is {type(guider_input_fields)}") + self._guider_input_fields = guider_input_fields + super().__init__() + @property def expected_components(self) -> List[ComponentSpec]: return [ @@ -59,49 +199,30 @@ class WanLoopDenoiser(ModularPipelineBlocks): @property def inputs(self) -> List[Tuple[str, Any]]: - return [ + inputs = [ InputParam("attention_kwargs"), - ] - - @property - def intermediate_inputs(self) -> List[str]: - return [ - InputParam( - "latents", - required=True, - type_hint=torch.Tensor, - description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", - ), InputParam( "num_inference_steps", required=True, type_hint=int, description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", ), - InputParam( - kwargs_type="denoiser_input_fields", - description=( - "All conditional model inputs that need to be prepared with guider. " - "It should contain prompt_embeds/negative_prompt_embeds. " - "Please add `kwargs_type=denoiser_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" - ), - ), ] + guider_input_names = [] + for value in self._guider_input_fields.values(): + if isinstance(value, tuple): + guider_input_names.extend(value) + else: + guider_input_names.append(value) + + for name in guider_input_names: + inputs.append(InputParam(name=name, required=True, type_hint=torch.Tensor)) + return inputs @torch.no_grad() def __call__( self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor ) -> PipelineState: - # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds) - # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds) - guider_inputs = { - "prompt_embeds": ( - getattr(block_state, "prompt_embeds", None), - getattr(block_state, "negative_prompt_embeds", None), - ), - } - transformer_dtype = components.transformer.dtype - components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) # The guider splits model inputs into separate batches for conditional/unconditional predictions. @@ -112,22 +233,26 @@ class WanLoopDenoiser(ModularPipelineBlocks): # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch # ] # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG). - guider_state = components.guider.prepare_inputs(guider_inputs) + guider_state = components.guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields) # run the denoiser for each guidance batch for guider_state_batch in guider_state: components.guider.prepare_models(components.transformer) - cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()} - prompt_embeds = cond_kwargs.pop("prompt_embeds") + cond_kwargs = guider_state_batch.as_dict() + cond_kwargs = { + k: v.to(block_state.dtype) if isinstance(v, torch.Tensor) else v + for k, v in cond_kwargs.items() + if k in self._guider_input_fields.keys() + } # Predict the noise residual # store the noise_pred in guider_state_batch so that we can apply guidance across all batches guider_state_batch.noise_pred = components.transformer( - hidden_states=block_state.latents.to(transformer_dtype), - timestep=t.flatten(), - encoder_hidden_states=prompt_embeds, + hidden_states=block_state.latent_model_input.to(block_state.dtype), + timestep=t.expand(block_state.latent_model_input.shape[0]).to(block_state.dtype), attention_kwargs=block_state.attention_kwargs, return_dict=False, + **cond_kwargs, )[0] components.guider.cleanup_models(components.transformer) @@ -137,6 +262,141 @@ class WanLoopDenoiser(ModularPipelineBlocks): return components, block_state +class Wan22LoopDenoiser(ModularPipelineBlocks): + model_name = "wan" + + def __init__( + self, + guider_input_fields: Dict[str, Any] = {"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds")}, + ): + """Initialize a denoiser block that calls the denoiser model. This block is used in Wan2.2. + + Args: + guider_input_fields: A dictionary that maps each argument expected by the denoiser model + (for example, "encoder_hidden_states") to data stored on `block_state`. The value can be either: + + - A tuple of strings. For instance, `{"encoder_hidden_states": ("prompt_embeds", + "negative_prompt_embeds")}` tells the guider to read `block_state.prompt_embeds` and + `block_state.negative_prompt_embeds` and pass them as the conditional and unconditional batches of + `encoder_hidden_states`. + - A string. For example, `{"encoder_hidden_image": "image_embeds"}` makes the guider forward + `block_state.image_embeds` for both conditional and unconditional batches. + """ + if not isinstance(guider_input_fields, dict): + raise ValueError(f"guider_input_fields must be a dictionary but is {type(guider_input_fields)}") + self._guider_input_fields = guider_input_fields + super().__init__() + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.0}), + default_creation_method="from_config", + ), + ComponentSpec( + "guider_2", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 3.0}), + default_creation_method="from_config", + ), + ComponentSpec("transformer", WanTransformer3DModel), + ComponentSpec("transformer_2", WanTransformer3DModel), + ] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoise the latents with guidance. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `WanDenoiseLoopWrapper`)" + ) + + @property + def expected_configs(self) -> List[ConfigSpec]: + return [ + ConfigSpec( + name="boundary_ratio", + default=0.875, + description="The boundary ratio to divide the denoising loop into high noise and low noise stages.", + ), + ] + + @property + def inputs(self) -> List[Tuple[str, Any]]: + inputs = [ + InputParam("attention_kwargs"), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", + ), + ] + guider_input_names = [] + for value in self._guider_input_fields.values(): + if isinstance(value, tuple): + guider_input_names.extend(value) + else: + guider_input_names.append(value) + + for name in guider_input_names: + inputs.append(InputParam(name=name, required=True, type_hint=torch.Tensor)) + return inputs + + @torch.no_grad() + def __call__( + self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + boundary_timestep = components.config.boundary_ratio * components.num_train_timesteps + if t >= boundary_timestep: + block_state.current_model = components.transformer + block_state.guider = components.guider + else: + block_state.current_model = components.transformer_2 + block_state.guider = components.guider_2 + + block_state.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + + # The guider splits model inputs into separate batches for conditional/unconditional predictions. + # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}: + # you will get a guider_state with two batches: + # guider_state = [ + # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch + # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch + # ] + # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG). + guider_state = block_state.guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields) + + # run the denoiser for each guidance batch + for guider_state_batch in guider_state: + block_state.guider.prepare_models(block_state.current_model) + cond_kwargs = guider_state_batch.as_dict() + cond_kwargs = { + k: v.to(block_state.dtype) if isinstance(v, torch.Tensor) else v + for k, v in cond_kwargs.items() + if k in self._guider_input_fields.keys() + } + + # Predict the noise residual + # store the noise_pred in guider_state_batch so that we can apply guidance across all batches + guider_state_batch.noise_pred = block_state.current_model( + hidden_states=block_state.latent_model_input.to(block_state.dtype), + timestep=t.expand(block_state.latent_model_input.shape[0]).to(block_state.dtype), + attention_kwargs=block_state.attention_kwargs, + return_dict=False, + **cond_kwargs, + )[0] + block_state.guider.cleanup_models(block_state.current_model) + + # Perform guidance + block_state.noise_pred = block_state.guider(guider_state)[0] + + return components, block_state + + class WanLoopAfterDenoiser(ModularPipelineBlocks): model_name = "wan" @@ -154,20 +414,6 @@ class WanLoopAfterDenoiser(ModularPipelineBlocks): "object (e.g. `WanDenoiseLoopWrapper`)" ) - @property - def inputs(self) -> List[Tuple[str, Any]]: - return [] - - @property - def intermediate_inputs(self) -> List[str]: - return [ - InputParam("generator"), - ] - - @property - def intermediate_outputs(self) -> List[OutputParam]: - return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] - @torch.no_grad() def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): # Perform scheduler step using the predicted output @@ -198,18 +444,11 @@ class WanDenoiseLoopWrapper(LoopSequentialPipelineBlocks): @property def loop_expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec( - "guider", - ClassifierFreeGuidance, - config=FrozenDict({"guidance_scale": 5.0}), - default_creation_method="from_config", - ), ComponentSpec("scheduler", UniPCMultistepScheduler), - ComponentSpec("transformer", WanTransformer3DModel), ] @property - def loop_intermediate_inputs(self) -> List[InputParam]: + def loop_inputs(self) -> List[InputParam]: return [ InputParam( "timesteps", @@ -248,7 +487,12 @@ class WanDenoiseLoopWrapper(LoopSequentialPipelineBlocks): class WanDenoiseStep(WanDenoiseLoopWrapper): block_classes = [ - WanLoopDenoiser, + WanLoopBeforeDenoiser, + WanLoopDenoiser( + guider_input_fields={ + "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"), + } + ), WanLoopAfterDenoiser, ] block_names = ["before_denoiser", "denoiser", "after_denoiser"] @@ -259,7 +503,110 @@ class WanDenoiseStep(WanDenoiseLoopWrapper): "Denoise step that iteratively denoise the latents. \n" "Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n" "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `WanLoopBeforeDenoiser`\n" " - `WanLoopDenoiser`\n" " - `WanLoopAfterDenoiser`\n" - "This block supports both text2vid tasks." + "This block supports text-to-video tasks for wan2.1." + ) + + +class Wan22DenoiseStep(WanDenoiseLoopWrapper): + block_classes = [ + WanLoopBeforeDenoiser, + Wan22LoopDenoiser( + guider_input_fields={ + "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"), + } + ), + WanLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `WanLoopBeforeDenoiser`\n" + " - `Wan22LoopDenoiser`\n" + " - `WanLoopAfterDenoiser`\n" + "This block supports text-to-video tasks for Wan2.2." + ) + + +class WanImage2VideoDenoiseStep(WanDenoiseLoopWrapper): + block_classes = [ + WanImage2VideoLoopBeforeDenoiser, + WanLoopDenoiser( + guider_input_fields={ + "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"), + "encoder_hidden_states_image": "image_embeds", + } + ), + WanLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `WanImage2VideoLoopBeforeDenoiser`\n" + " - `WanLoopDenoiser`\n" + " - `WanLoopAfterDenoiser`\n" + "This block supports image-to-video tasks for wan2.1." + ) + + +class Wan22Image2VideoDenoiseStep(WanDenoiseLoopWrapper): + block_classes = [ + WanImage2VideoLoopBeforeDenoiser, + Wan22LoopDenoiser( + guider_input_fields={ + "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"), + } + ), + WanLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `WanImage2VideoLoopBeforeDenoiser`\n" + " - `WanLoopDenoiser`\n" + " - `WanLoopAfterDenoiser`\n" + "This block supports image-to-video tasks for Wan2.2." + ) + + +class WanFLF2VDenoiseStep(WanDenoiseLoopWrapper): + block_classes = [ + WanFLF2VLoopBeforeDenoiser, + WanLoopDenoiser( + guider_input_fields={ + "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"), + "encoder_hidden_states_image": "image_embeds", + } + ), + WanLoopAfterDenoiser, + ] + block_names = ["before_denoiser", "denoiser", "after_denoiser"] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. \n" + "Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n" + "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n" + " - `WanFLF2VLoopBeforeDenoiser`\n" + " - `WanLoopDenoiser`\n" + " - `WanLoopAfterDenoiser`\n" + "This block supports FLF2V tasks for wan2.1." ) diff --git a/src/diffusers/modular_pipelines/wan/encoders.py b/src/diffusers/modular_pipelines/wan/encoders.py index cb2fc24238..dc49df8eab 100644 --- a/src/diffusers/modular_pipelines/wan/encoders.py +++ b/src/diffusers/modular_pipelines/wan/encoders.py @@ -15,21 +15,29 @@ import html from typing import List, Optional, Union +import numpy as np +import PIL import regex as re import torch -from transformers import AutoTokenizer, UMT5EncoderModel +from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel from ...configuration_utils import FrozenDict from ...guiders import ClassifierFreeGuidance -from ...utils import is_ftfy_available, logging +from ...image_processor import PipelineImageInput +from ...models import AutoencoderKLWan +from ...utils import is_ftfy_available, is_torchvision_available, logging +from ...video_processor import VideoProcessor from ..modular_pipeline import ModularPipelineBlocks, PipelineState -from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam from .modular_pipeline import WanModularPipeline if is_ftfy_available(): import ftfy +if is_torchvision_available(): + from torchvision import transforms + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -51,6 +59,103 @@ def prompt_clean(text): return text +def get_t5_prompt_embeds( + text_encoder: UMT5EncoderModel, + tokenizer: AutoTokenizer, + prompt: Union[str, List[str]], + max_sequence_length: int, + device: torch.device, +): + dtype = text_encoder.dtype + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + prompt_embeds = text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + return prompt_embeds + + +def encode_image( + image: PipelineImageInput, + image_processor: CLIPImageProcessor, + image_encoder: CLIPVisionModel, + device: Optional[torch.device] = None, +): + image = image_processor(images=image, return_tensors="pt").to(device) + image_embeds = image_encoder(**image, output_hidden_states=True) + return image_embeds.hidden_states[-2] + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def encode_vae_image( + video_tensor: torch.Tensor, + vae: AutoencoderKLWan, + generator: torch.Generator, + device: torch.device, + dtype: torch.dtype, + latent_channels: int = 16, +): + if not isinstance(video_tensor, torch.Tensor): + raise ValueError(f"Expected video_tensor to be a tensor, got {type(video_tensor)}.") + + if isinstance(generator, list) and len(generator) != video_tensor.shape[0]: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but it is not same as number of images {video_tensor.shape[0]}." + ) + + video_tensor = video_tensor.to(device=device, dtype=dtype) + + if isinstance(generator, list): + video_latents = [ + retrieve_latents(vae.encode(video_tensor[i : i + 1]), generator=generator[i], sample_mode="argmax") + for i in range(video_tensor.shape[0]) + ] + video_latents = torch.cat(video_latents, dim=0) + else: + video_latents = retrieve_latents(vae.encode(video_tensor), sample_mode="argmax") + + latents_mean = ( + torch.tensor(vae.config.latents_mean) + .view(1, latent_channels, 1, 1, 1) + .to(video_latents.device, video_latents.dtype) + ) + latents_std = 1.0 / torch.tensor(vae.config.latents_std).view(1, latent_channels, 1, 1, 1).to( + video_latents.device, video_latents.dtype + ) + video_latents = (video_latents - latents_mean) * latents_std + + return video_latents + + class WanTextEncoderStep(ModularPipelineBlocks): model_name = "wan" @@ -71,16 +176,12 @@ class WanTextEncoderStep(ModularPipelineBlocks): ), ] - @property - def expected_configs(self) -> List[ConfigSpec]: - return [] - @property def inputs(self) -> List[InputParam]: return [ InputParam("prompt"), InputParam("negative_prompt"), - InputParam("attention_kwargs"), + InputParam("max_sequence_length", default=512), ] @property @@ -107,47 +208,13 @@ class WanTextEncoderStep(ModularPipelineBlocks): ): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}") - @staticmethod - def _get_t5_prompt_embeds( - components, - prompt: Union[str, List[str]], - max_sequence_length: int, - device: torch.device, - ): - dtype = components.text_encoder.dtype - prompt = [prompt] if isinstance(prompt, str) else prompt - prompt = [prompt_clean(u) for u in prompt] - - text_inputs = components.tokenizer( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - add_special_tokens=True, - return_attention_mask=True, - return_tensors="pt", - ) - text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask - seq_lens = mask.gt(0).sum(dim=1).long() - prompt_embeds = components.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 - ) - - return prompt_embeds - @staticmethod def encode_prompt( components, prompt: str, device: Optional[torch.device] = None, - num_videos_per_prompt: int = 1, prepare_unconditional_embeds: bool = True, negative_prompt: Optional[str] = None, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, max_sequence_length: int = 512, ): r""" @@ -158,32 +225,29 @@ class WanTextEncoderStep(ModularPipelineBlocks): prompt to be encoded device: (`torch.device`): torch device - num_videos_per_prompt (`int`): - number of videos that should be generated per prompt prepare_unconditional_embeds (`bool`): whether to use prepare unconditional embeddings or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. max_sequence_length (`int`, defaults to `512`): The maximum number of text tokens to be used for the generation process. """ device = device or components._execution_device - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0] + if not isinstance(prompt, list): + prompt = [prompt] + batch_size = len(prompt) - if prompt_embeds is None: - prompt_embeds = WanTextEncoderStep._get_t5_prompt_embeds(components, prompt, max_sequence_length, device) + prompt_embeds = get_t5_prompt_embeds( + text_encoder=components.text_encoder, + tokenizer=components.tokenizer, + prompt=prompt, + max_sequence_length=max_sequence_length, + device=device, + ) - if prepare_unconditional_embeds and negative_prompt_embeds is None: + if prepare_unconditional_embeds: negative_prompt = negative_prompt or "" negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt @@ -199,18 +263,14 @@ class WanTextEncoderStep(ModularPipelineBlocks): " the batch size of `prompt`." ) - negative_prompt_embeds = WanTextEncoderStep._get_t5_prompt_embeds( - components, negative_prompt, max_sequence_length, device + negative_prompt_embeds = get_t5_prompt_embeds( + text_encoder=components.text_encoder, + tokenizer=components.tokenizer, + prompt=negative_prompt, + max_sequence_length=max_sequence_length, + device=device, ) - bs_embed, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) - - if prepare_unconditional_embeds: - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) - return prompt_embeds, negative_prompt_embeds @torch.no_grad() @@ -219,7 +279,6 @@ class WanTextEncoderStep(ModularPipelineBlocks): block_state = self.get_block_state(state) self.check_inputs(block_state) - block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1 block_state.device = components._execution_device # Encode input prompt @@ -227,16 +286,382 @@ class WanTextEncoderStep(ModularPipelineBlocks): block_state.prompt_embeds, block_state.negative_prompt_embeds, ) = self.encode_prompt( - components, - block_state.prompt, - block_state.device, - 1, - block_state.prepare_unconditional_embeds, - block_state.negative_prompt, - prompt_embeds=None, - negative_prompt_embeds=None, + components=components, + prompt=block_state.prompt, + device=block_state.device, + prepare_unconditional_embeds=components.requires_unconditional_embeds, + negative_prompt=block_state.negative_prompt, + max_sequence_length=block_state.max_sequence_length, ) # Add outputs self.set_block_state(state, block_state) return components, state + + +class WanImageResizeStep(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return "Image Resize step that resize the image to the target area (height * width) while maintaining the aspect ratio." + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("image", type_hint=PIL.Image.Image, required=True), + InputParam("height", type_hint=int, default=480), + InputParam("width", type_hint=int, default=832), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam("resized_image", type_hint=PIL.Image.Image), + ] + + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + max_area = block_state.height * block_state.width + + image = block_state.image + aspect_ratio = image.height / image.width + mod_value = components.vae_scale_factor_spatial * components.patch_size_spatial + block_state.height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + block_state.width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + block_state.resized_image = image.resize((block_state.width, block_state.height)) + + self.set_block_state(state, block_state) + return components, state + + +class WanImageCropResizeStep(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return "Image Resize step that resize the last_image to the same size of first frame image with center crop." + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam( + "resized_image", type_hint=PIL.Image.Image, required=True, description="The resized first frame image" + ), + InputParam("last_image", type_hint=PIL.Image.Image, required=True, description="The last frameimage"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam("resized_last_image", type_hint=PIL.Image.Image), + ] + + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + height = block_state.resized_image.height + width = block_state.resized_image.width + image = block_state.last_image + + # Calculate resize ratio to match first frame dimensions + resize_ratio = max(width / image.width, height / image.height) + + # Resize the image + width = round(image.width * resize_ratio) + height = round(image.height * resize_ratio) + size = [width, height] + resized_image = transforms.functional.center_crop(image, size) + block_state.resized_last_image = resized_image + + self.set_block_state(state, block_state) + return components, state + + +class WanImageEncoderStep(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return "Image Encoder step that generate image_embeds based on first frame image to guide the video generation" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("image_processor", CLIPImageProcessor), + ComponentSpec("image_encoder", CLIPVisionModel), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("resized_image", type_hint=PIL.Image.Image, required=True), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam("image_embeds", type_hint=torch.Tensor, description="The image embeddings"), + ] + + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + device = components._execution_device + + image = block_state.resized_image + + image_embeds = encode_image( + image_processor=components.image_processor, + image_encoder=components.image_encoder, + image=image, + device=device, + ) + block_state.image_embeds = image_embeds + self.set_block_state(state, block_state) + return components, state + + +class WanFirstLastFrameImageEncoderStep(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return "Image Encoder step that generate image_embeds based on first and last frame images to guide the video generation" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("image_processor", CLIPImageProcessor), + ComponentSpec("image_encoder", CLIPVisionModel), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("resized_image", type_hint=PIL.Image.Image, required=True), + InputParam("resized_last_image", type_hint=PIL.Image.Image, required=True), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam("image_embeds", type_hint=torch.Tensor, description="The image embeddings"), + ] + + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + device = components._execution_device + + first_frame_image = block_state.resized_image + last_frame_image = block_state.resized_last_image + + image_embeds = encode_image( + image_processor=components.image_processor, + image_encoder=components.image_encoder, + image=[first_frame_image, last_frame_image], + device=device, + ) + block_state.image_embeds = image_embeds + self.set_block_state(state, block_state) + return components, state + + +class WanVaeImageEncoderStep(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return "Vae Image Encoder step that generate condition_latents based on first frame image to guide the video generation" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKLWan), + ComponentSpec( + "video_processor", + VideoProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("resized_image", type_hint=PIL.Image.Image, required=True), + InputParam("height"), + InputParam("width"), + InputParam("num_frames"), + InputParam("generator"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "first_frame_latents", + type_hint=torch.Tensor, + description="video latent representation with the first frame image condition", + ), + ] + + @staticmethod + def check_inputs(components, block_state): + if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or ( + block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}." + ) + if block_state.num_frames is not None and ( + block_state.num_frames < 1 or (block_state.num_frames - 1) % components.vae_scale_factor_temporal != 0 + ): + raise ValueError( + f"`num_frames` has to be greater than 0, and (num_frames - 1) must be divisible by {components.vae_scale_factor_temporal}, but got {block_state.num_frames}." + ) + + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + image = block_state.resized_image + + device = components._execution_device + dtype = torch.float32 + + height = block_state.height or components.default_height + width = block_state.width or components.default_width + num_frames = block_state.num_frames or components.default_num_frames + + image_tensor = components.video_processor.preprocess(image, height=height, width=width).to( + device=device, dtype=dtype + ) + + if image_tensor.dim() == 4: + image_tensor = image_tensor.unsqueeze(2) + + video_tensor = torch.cat( + [ + image_tensor, + image_tensor.new_zeros(image_tensor.shape[0], image_tensor.shape[1], num_frames - 1, height, width), + ], + dim=2, + ).to(device=device, dtype=dtype) + + block_state.first_frame_latents = encode_vae_image( + video_tensor=video_tensor, + vae=components.vae, + generator=block_state.generator, + device=device, + dtype=dtype, + latent_channels=components.num_channels_latents, + ) + + self.set_block_state(state, block_state) + return components, state + + +class WanFirstLastFrameVaeImageEncoderStep(ModularPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return "Vae Image Encoder step that generate condition_latents based on first and last frame images to guide the video generation" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKLWan), + ComponentSpec( + "video_processor", + VideoProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("resized_image", type_hint=PIL.Image.Image, required=True), + InputParam("resized_last_image", type_hint=PIL.Image.Image, required=True), + InputParam("height"), + InputParam("width"), + InputParam("num_frames"), + InputParam("generator"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "first_last_frame_latents", + type_hint=torch.Tensor, + description="video latent representation with the first and last frame images condition", + ), + ] + + @staticmethod + def check_inputs(components, block_state): + if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or ( + block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}." + ) + if block_state.num_frames is not None and ( + block_state.num_frames < 1 or (block_state.num_frames - 1) % components.vae_scale_factor_temporal != 0 + ): + raise ValueError( + f"`num_frames` has to be greater than 0, and (num_frames - 1) must be divisible by {components.vae_scale_factor_temporal}, but got {block_state.num_frames}." + ) + + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + first_frame_image = block_state.resized_image + last_frame_image = block_state.resized_last_image + + device = components._execution_device + dtype = torch.float32 + + height = block_state.height or components.default_height + width = block_state.width or components.default_width + num_frames = block_state.num_frames or components.default_num_frames + + first_image_tensor = components.video_processor.preprocess(first_frame_image, height=height, width=width).to( + device=device, dtype=dtype + ) + first_image_tensor = first_image_tensor.unsqueeze(2) + + last_image_tensor = components.video_processor.preprocess(last_frame_image, height=height, width=width).to( + device=device, dtype=dtype + ) + + last_image_tensor = last_image_tensor.unsqueeze(2) + + video_tensor = torch.cat( + [ + first_image_tensor, + first_image_tensor.new_zeros( + first_image_tensor.shape[0], first_image_tensor.shape[1], num_frames - 2, height, width + ), + last_image_tensor, + ], + dim=2, + ).to(device=device, dtype=dtype) + + block_state.first_last_frame_latents = encode_vae_image( + video_tensor=video_tensor, + vae=components.vae, + generator=block_state.generator, + device=device, + dtype=dtype, + latent_channels=components.num_channels_latents, + ) + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/wan/modular_blocks.py b/src/diffusers/modular_pipelines/wan/modular_blocks.py index 5f4c1a9835..b3b70b2f9b 100644 --- a/src/diffusers/modular_pipelines/wan/modular_blocks.py +++ b/src/diffusers/modular_pipelines/wan/modular_blocks.py @@ -16,96 +16,244 @@ from ...utils import logging from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks from ..modular_pipeline_utils import InsertableDict from .before_denoise import ( - WanInputStep, + WanAdditionalInputsStep, + WanPrepareFirstFrameLatentsStep, + WanPrepareFirstLastFrameLatentsStep, WanPrepareLatentsStep, WanSetTimestepsStep, + WanTextInputStep, +) +from .decoders import WanImageVaeDecoderStep +from .denoise import ( + Wan22DenoiseStep, + Wan22Image2VideoDenoiseStep, + WanDenoiseStep, + WanFLF2VDenoiseStep, + WanImage2VideoDenoiseStep, +) +from .encoders import ( + WanFirstLastFrameImageEncoderStep, + WanFirstLastFrameVaeImageEncoderStep, + WanImageCropResizeStep, + WanImageEncoderStep, + WanImageResizeStep, + WanTextEncoderStep, + WanVaeImageEncoderStep, ) -from .decoders import WanDecodeStep -from .denoise import WanDenoiseStep -from .encoders import WanTextEncoderStep logger = logging.get_logger(__name__) # pylint: disable=invalid-name -# before_denoise: text2vid -class WanBeforeDenoiseStep(SequentialPipelineBlocks): +# wan2.1 +# wan2.1: text2vid +class WanCoreDenoiseStep(SequentialPipelineBlocks): block_classes = [ - WanInputStep, + WanTextInputStep, WanSetTimestepsStep, WanPrepareLatentsStep, - ] - block_names = ["input", "set_timesteps", "prepare_latents"] - - @property - def description(self): - return ( - "Before denoise step that prepare the inputs for the denoise step.\n" - + "This is a sequential pipeline blocks:\n" - + " - `WanInputStep` is used to adjust the batch size of the model inputs\n" - + " - `WanSetTimestepsStep` is used to set the timesteps\n" - + " - `WanPrepareLatentsStep` is used to prepare the latents\n" - ) - - -# before_denoise: all task (text2vid,) -class WanAutoBeforeDenoiseStep(AutoPipelineBlocks): - block_classes = [ - WanBeforeDenoiseStep, - ] - block_names = ["text2vid"] - block_trigger_inputs = [None] - - @property - def description(self): - return ( - "Before denoise step that prepare the inputs for the denoise step.\n" - + "This is an auto pipeline block that works for text2vid.\n" - + " - `WanBeforeDenoiseStep` (text2vid) is used.\n" - ) - - -# denoise: text2vid -class WanAutoDenoiseStep(AutoPipelineBlocks): - block_classes = [ WanDenoiseStep, ] - block_names = ["denoise"] - block_trigger_inputs = [None] + block_names = ["input", "set_timesteps", "prepare_latents", "denoise"] + + @property + def description(self): + return ( + "denoise block that takes encoded conditions and runs the denoising process.\n" + + "This is a sequential pipeline blocks:\n" + + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n" + + " - `WanSetTimestepsStep` is used to set the timesteps\n" + + " - `WanPrepareLatentsStep` is used to prepare the latents\n" + + " - `WanDenoiseStep` is used to denoise the latents\n" + ) + + +# wan2.1: image2video +## image encoder +class WanImage2VideoImageEncoderStep(SequentialPipelineBlocks): + model_name = "wan" + block_classes = [WanImageResizeStep, WanImageEncoderStep] + block_names = ["image_resize", "image_encoder"] + + @property + def description(self): + return "Image2Video Image Encoder step that resize the image and encode the image to generate the image embeddings" + + +## vae encoder +class WanImage2VideoVaeImageEncoderStep(SequentialPipelineBlocks): + model_name = "wan" + block_classes = [WanImageResizeStep, WanVaeImageEncoderStep] + block_names = ["image_resize", "vae_image_encoder"] + + @property + def description(self): + return "Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent representation" + + +## denoise +class WanImage2VideoCoreDenoiseStep(SequentialPipelineBlocks): + block_classes = [ + WanTextInputStep, + WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"]), + WanSetTimestepsStep, + WanPrepareLatentsStep, + WanPrepareFirstFrameLatentsStep, + WanImage2VideoDenoiseStep, + ] + block_names = [ + "input", + "additional_inputs", + "set_timesteps", + "prepare_latents", + "prepare_first_frame_latents", + "denoise", + ] + + @property + def description(self): + return ( + "denoise block that takes encoded text and image latent conditions and runs the denoising process.\n" + + "This is a sequential pipeline blocks:\n" + + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n" + + " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n" + + " - `WanSetTimestepsStep` is used to set the timesteps\n" + + " - `WanPrepareLatentsStep` is used to prepare the latents\n" + + " - `WanPrepareFirstFrameLatentsStep` is used to prepare the first frame latent conditions\n" + + " - `WanImage2VideoDenoiseStep` is used to denoise the latents\n" + ) + + +# wan2.1: FLF2v + + +## image encoder +class WanFLF2VImageEncoderStep(SequentialPipelineBlocks): + model_name = "wan" + block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameImageEncoderStep] + block_names = ["image_resize", "last_image_resize", "image_encoder"] + + @property + def description(self): + return "FLF2V Image Encoder step that resize and encode and encode the first and last frame images to generate the image embeddings" + + +## vae encoder +class WanFLF2VVaeImageEncoderStep(SequentialPipelineBlocks): + model_name = "wan" + block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameVaeImageEncoderStep] + block_names = ["image_resize", "last_image_resize", "vae_image_encoder"] + + @property + def description(self): + return "FLF2V Vae Image Encoder step that resize and encode and encode the first and last frame images to generate the latent conditions" + + +## denoise +class WanFLF2VCoreDenoiseStep(SequentialPipelineBlocks): + block_classes = [ + WanTextInputStep, + WanAdditionalInputsStep(image_latent_inputs=["first_last_frame_latents"]), + WanSetTimestepsStep, + WanPrepareLatentsStep, + WanPrepareFirstLastFrameLatentsStep, + WanFLF2VDenoiseStep, + ] + block_names = [ + "input", + "additional_inputs", + "set_timesteps", + "prepare_latents", + "prepare_first_last_frame_latents", + "denoise", + ] + + @property + def description(self): + return ( + "denoise block that takes encoded text and image latent conditions and runs the denoising process.\n" + + "This is a sequential pipeline blocks:\n" + + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n" + + " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n" + + " - `WanSetTimestepsStep` is used to set the timesteps\n" + + " - `WanPrepareLatentsStep` is used to prepare the latents\n" + + " - `WanPrepareFirstLastFrameLatentsStep` is used to prepare the latent conditions\n" + + " - `WanImage2VideoDenoiseStep` is used to denoise the latents\n" + ) + + +# wan2.1: auto blocks +## image encoder +class WanAutoImageEncoderStep(AutoPipelineBlocks): + block_classes = [WanFLF2VImageEncoderStep, WanImage2VideoImageEncoderStep] + block_names = ["flf2v_image_encoder", "image2video_image_encoder"] + block_trigger_inputs = ["last_image", "image"] + + @property + def description(self): + return ( + "Image Encoder step that encode the image to generate the image embeddings" + + "This is an auto pipeline block that works for image2video tasks." + + " - `WanFLF2VImageEncoderStep` (flf2v) is used when `last_image` is provided." + + " - `WanImage2VideoImageEncoderStep` (image2video) is used when `image` is provided." + + " - if `last_image` or `image` is not provided, step will be skipped." + ) + + +## vae encoder +class WanAutoVaeImageEncoderStep(AutoPipelineBlocks): + block_classes = [WanFLF2VVaeImageEncoderStep, WanImage2VideoVaeImageEncoderStep] + block_names = ["flf2v_vae_image_encoder", "image2video_vae_image_encoder"] + block_trigger_inputs = ["last_image", "image"] + + @property + def description(self): + return ( + "Vae Image Encoder step that encode the image to generate the image latents" + + "This is an auto pipeline block that works for image2video tasks." + + " - `WanFLF2VVaeImageEncoderStep` (flf2v) is used when `last_image` is provided." + + " - `WanImage2VideoVaeImageEncoderStep` (image2video) is used when `image` is provided." + + " - if `last_image` or `image` is not provided, step will be skipped." + ) + + +## denoise +class WanAutoDenoiseStep(AutoPipelineBlocks): + block_classes = [ + WanFLF2VCoreDenoiseStep, + WanImage2VideoCoreDenoiseStep, + WanCoreDenoiseStep, + ] + block_names = ["flf2v", "image2video", "text2video"] + block_trigger_inputs = ["first_last_frame_latents", "first_frame_latents", None] @property def description(self) -> str: return ( "Denoise step that iteratively denoise the latents. " - "This is a auto pipeline block that works for text2vid tasks.." - " - `WanDenoiseStep` (denoise) for text2vid tasks." + "This is a auto pipeline block that works for text2video and image2video tasks." + " - `WanCoreDenoiseStep` (text2video) for text2vid tasks." + " - `WanCoreImage2VideoCoreDenoiseStep` (image2video) for image2video tasks." + + " - if `first_frame_latents` is provided, `WanCoreImage2VideoDenoiseStep` will be used.\n" + + " - if `first_frame_latents` is not provided, `WanCoreDenoiseStep` will be used.\n" ) -# decode: all task (text2img, img2img, inpainting) -class WanAutoDecodeStep(AutoPipelineBlocks): - block_classes = [WanDecodeStep] - block_names = ["non-inpaint"] - block_trigger_inputs = [None] - - @property - def description(self): - return "Decode step that decode the denoised latents into videos outputs.\n - `WanDecodeStep`" - - -# text2vid +# auto pipeline blocks class WanAutoBlocks(SequentialPipelineBlocks): block_classes = [ WanTextEncoderStep, - WanAutoBeforeDenoiseStep, + WanAutoImageEncoderStep, + WanAutoVaeImageEncoderStep, WanAutoDenoiseStep, - WanAutoDecodeStep, + WanImageVaeDecoderStep, ] block_names = [ "text_encoder", - "before_denoise", + "image_encoder", + "vae_image_encoder", "denoise", - "decoder", + "decode", ] @property @@ -116,29 +264,211 @@ class WanAutoBlocks(SequentialPipelineBlocks): ) +# wan22 +# wan2.2: text2vid + + +## denoise +class Wan22CoreDenoiseStep(SequentialPipelineBlocks): + block_classes = [ + WanTextInputStep, + WanSetTimestepsStep, + WanPrepareLatentsStep, + Wan22DenoiseStep, + ] + block_names = ["input", "set_timesteps", "prepare_latents", "denoise"] + + @property + def description(self): + return ( + "denoise block that takes encoded conditions and runs the denoising process.\n" + + "This is a sequential pipeline blocks:\n" + + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n" + + " - `WanSetTimestepsStep` is used to set the timesteps\n" + + " - `WanPrepareLatentsStep` is used to prepare the latents\n" + + " - `Wan22DenoiseStep` is used to denoise the latents in wan2.2\n" + ) + + +# wan2.2: image2video +## denoise +class Wan22Image2VideoCoreDenoiseStep(SequentialPipelineBlocks): + block_classes = [ + WanTextInputStep, + WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"]), + WanSetTimestepsStep, + WanPrepareLatentsStep, + WanPrepareFirstFrameLatentsStep, + Wan22Image2VideoDenoiseStep, + ] + block_names = [ + "input", + "additional_inputs", + "set_timesteps", + "prepare_latents", + "prepare_first_frame_latents", + "denoise", + ] + + @property + def description(self): + return ( + "denoise block that takes encoded text and image latent conditions and runs the denoising process.\n" + + "This is a sequential pipeline blocks:\n" + + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n" + + " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n" + + " - `WanSetTimestepsStep` is used to set the timesteps\n" + + " - `WanPrepareLatentsStep` is used to prepare the latents\n" + + " - `WanPrepareFirstFrameLatentsStep` is used to prepare the first frame latent conditions\n" + + " - `Wan22Image2VideoDenoiseStep` is used to denoise the latents in wan2.2\n" + ) + + +class Wan22AutoDenoiseStep(AutoPipelineBlocks): + block_classes = [ + Wan22Image2VideoCoreDenoiseStep, + Wan22CoreDenoiseStep, + ] + block_names = ["image2video", "text2video"] + block_trigger_inputs = ["first_frame_latents", None] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoise the latents. " + "This is a auto pipeline block that works for text2video and image2video tasks." + " - `Wan22Image2VideoCoreDenoiseStep` (image2video) for image2video tasks." + " - `Wan22CoreDenoiseStep` (text2video) for text2vid tasks." + + " - if `first_frame_latents` is provided, `Wan22Image2VideoCoreDenoiseStep` will be used.\n" + + " - if `first_frame_latents` is not provided, `Wan22CoreDenoiseStep` will be used.\n" + ) + + +class Wan22AutoBlocks(SequentialPipelineBlocks): + block_classes = [ + WanTextEncoderStep, + WanAutoVaeImageEncoderStep, + Wan22AutoDenoiseStep, + WanImageVaeDecoderStep, + ] + block_names = [ + "text_encoder", + "vae_image_encoder", + "denoise", + "decode", + ] + + @property + def description(self): + return ( + "Auto Modular pipeline for text-to-video using Wan2.2.\n" + + "- for text-to-video generation, all you need to provide is `prompt`" + ) + + +# presets for wan2.1 and wan2.2 +# YiYi Notes: should we move these to doc? +# wan2.1 TEXT2VIDEO_BLOCKS = InsertableDict( [ ("text_encoder", WanTextEncoderStep), - ("input", WanInputStep), + ("input", WanTextInputStep), ("set_timesteps", WanSetTimestepsStep), ("prepare_latents", WanPrepareLatentsStep), ("denoise", WanDenoiseStep), - ("decode", WanDecodeStep), + ("decode", WanImageVaeDecoderStep), ] ) +IMAGE2VIDEO_BLOCKS = InsertableDict( + [ + ("image_resize", WanImageResizeStep), + ("image_encoder", WanImage2VideoImageEncoderStep), + ("vae_image_encoder", WanImage2VideoVaeImageEncoderStep), + ("input", WanTextInputStep), + ("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"])), + ("set_timesteps", WanSetTimestepsStep), + ("prepare_latents", WanPrepareLatentsStep), + ("prepare_first_frame_latents", WanPrepareFirstFrameLatentsStep), + ("denoise", WanImage2VideoDenoiseStep), + ("decode", WanImageVaeDecoderStep), + ] +) + + +FLF2V_BLOCKS = InsertableDict( + [ + ("image_resize", WanImageResizeStep), + ("last_image_resize", WanImageCropResizeStep), + ("image_encoder", WanFLF2VImageEncoderStep), + ("vae_image_encoder", WanFLF2VVaeImageEncoderStep), + ("input", WanTextInputStep), + ("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_last_frame_latents"])), + ("set_timesteps", WanSetTimestepsStep), + ("prepare_latents", WanPrepareLatentsStep), + ("prepare_first_last_frame_latents", WanPrepareFirstLastFrameLatentsStep), + ("denoise", WanFLF2VDenoiseStep), + ("decode", WanImageVaeDecoderStep), + ] +) AUTO_BLOCKS = InsertableDict( [ ("text_encoder", WanTextEncoderStep), - ("before_denoise", WanAutoBeforeDenoiseStep), + ("image_encoder", WanAutoImageEncoderStep), + ("vae_image_encoder", WanAutoVaeImageEncoderStep), ("denoise", WanAutoDenoiseStep), - ("decode", WanAutoDecodeStep), + ("decode", WanImageVaeDecoderStep), ] ) +# wan2.2 presets + +TEXT2VIDEO_BLOCKS_WAN22 = InsertableDict( + [ + ("text_encoder", WanTextEncoderStep), + ("input", WanTextInputStep), + ("set_timesteps", WanSetTimestepsStep), + ("prepare_latents", WanPrepareLatentsStep), + ("denoise", Wan22DenoiseStep), + ("decode", WanImageVaeDecoderStep), + ] +) + +IMAGE2VIDEO_BLOCKS_WAN22 = InsertableDict( + [ + ("image_resize", WanImageResizeStep), + ("vae_image_encoder", WanImage2VideoVaeImageEncoderStep), + ("input", WanTextInputStep), + ("set_timesteps", WanSetTimestepsStep), + ("prepare_latents", WanPrepareLatentsStep), + ("denoise", Wan22DenoiseStep), + ("decode", WanImageVaeDecoderStep), + ] +) + +AUTO_BLOCKS_WAN22 = InsertableDict( + [ + ("text_encoder", WanTextEncoderStep), + ("vae_image_encoder", WanAutoVaeImageEncoderStep), + ("denoise", Wan22AutoDenoiseStep), + ("decode", WanImageVaeDecoderStep), + ] +) + +# presets all blocks (wan and wan22) + ALL_BLOCKS = { - "text2video": TEXT2VIDEO_BLOCKS, - "auto": AUTO_BLOCKS, + "wan2.1": { + "text2video": TEXT2VIDEO_BLOCKS, + "image2video": IMAGE2VIDEO_BLOCKS, + "flf2v": FLF2V_BLOCKS, + "auto": AUTO_BLOCKS, + }, + "wan2.2": { + "text2video": TEXT2VIDEO_BLOCKS_WAN22, + "image2video": IMAGE2VIDEO_BLOCKS_WAN22, + "auto": AUTO_BLOCKS_WAN22, + }, } diff --git a/src/diffusers/modular_pipelines/wan/modular_pipeline.py b/src/diffusers/modular_pipelines/wan/modular_pipeline.py index e4adf3d151..930b25e4b9 100644 --- a/src/diffusers/modular_pipelines/wan/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/wan/modular_pipeline.py @@ -13,6 +13,8 @@ # limitations under the License. +from typing import Any, Dict, Optional + from ...loaders import WanLoraLoaderMixin from ...pipelines.pipeline_utils import StableDiffusionMixin from ...utils import logging @@ -35,6 +37,13 @@ class WanModularPipeline( default_blocks_name = "WanAutoBlocks" + # override the default_blocks_name in base class, which is just return self.default_blocks_name + def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]: + if config_dict is not None and "boundary_ratio" in config_dict and config_dict["boundary_ratio"] is not None: + return "Wan22AutoBlocks" + else: + return "WanAutoBlocks" + @property def default_height(self): return self.default_sample_height * self.vae_scale_factor_spatial @@ -59,6 +68,13 @@ class WanModularPipeline( def default_sample_num_frames(self): return 21 + @property + def patch_size_spatial(self): + patch_size_spatial = 2 + if hasattr(self, "transformer") and self.transformer is not None: + patch_size_spatial = self.transformer.config.patch_size[1] + return patch_size_spatial + @property def vae_scale_factor_spatial(self): vae_scale_factor = 8 @@ -86,3 +102,19 @@ class WanModularPipeline( if hasattr(self, "vae") and self.vae is not None: num_channels_latents = self.vae.config.z_dim return num_channels_latents + + @property + def requires_unconditional_embeds(self): + requires_unconditional_embeds = False + + if hasattr(self, "guider") and self.guider is not None: + requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1 + + return requires_unconditional_embeds + + @property + def num_train_timesteps(self): + num_train_timesteps = 1000 + if hasattr(self, "scheduler") and self.scheduler is not None: + num_train_timesteps = self.scheduler.config.num_train_timesteps + return num_train_timesteps diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 8a32d4c367..044d854390 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -117,6 +117,7 @@ from .stable_diffusion_xl import ( StableDiffusionXLInpaintPipeline, StableDiffusionXLPipeline, ) +from .wan import WanImageToVideoPipeline, WanPipeline, WanVideoToVideoPipeline from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline @@ -214,6 +215,24 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict( ] ) +AUTO_TEXT2VIDEO_PIPELINES_MAPPING = OrderedDict( + [ + ("wan", WanPipeline), + ] +) + +AUTO_IMAGE2VIDEO_PIPELINES_MAPPING = OrderedDict( + [ + ("wan", WanImageToVideoPipeline), + ] +) + +AUTO_VIDEO2VIDEO_PIPELINES_MAPPING = OrderedDict( + [ + ("wan", WanVideoToVideoPipeline), + ] +) + _AUTO_TEXT2IMAGE_DECODER_PIPELINES_MAPPING = OrderedDict( [ ("kandinsky", KandinskyPipeline), @@ -247,6 +266,9 @@ SUPPORTED_TASKS_MAPPINGS = [ AUTO_TEXT2IMAGE_PIPELINES_MAPPING, AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, AUTO_INPAINT_PIPELINES_MAPPING, + AUTO_TEXT2VIDEO_PIPELINES_MAPPING, + AUTO_IMAGE2VIDEO_PIPELINES_MAPPING, + AUTO_VIDEO2VIDEO_PIPELINES_MAPPING, _AUTO_TEXT2IMAGE_DECODER_PIPELINES_MAPPING, _AUTO_IMAGE2IMAGE_DECODER_PIPELINES_MAPPING, _AUTO_INPAINT_DECODER_PIPELINES_MAPPING, diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index a52fb46817..19f6c0f584 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -182,6 +182,21 @@ class StableDiffusionXLModularPipeline(metaclass=DummyObject): requires_backends(cls, ["torch", "transformers"]) +class Wan22AutoBlocks(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class WanAutoBlocks(metaclass=DummyObject): _backends = ["torch", "transformers"]