diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index 2b892a91ae..0b2e721b94 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -971,6 +971,7 @@ class DreamBoothDataset(Dataset): def __init__( self, + args, instance_data_root, instance_prompt, class_prompt, @@ -980,10 +981,8 @@ class DreamBoothDataset(Dataset): class_num=None, size=1024, repeats=1, - center_crop=False, ): self.size = size - self.center_crop = center_crop self.instance_prompt = instance_prompt self.custom_instance_prompts = None @@ -1058,7 +1057,7 @@ class DreamBoothDataset(Dataset): if interpolation is None: raise ValueError(f"Unsupported interpolation mode {interpolation=}.") train_resize = transforms.Resize(size, interpolation=interpolation) - train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) + train_crop = transforms.CenterCrop(size) if args.center_crop else transforms.RandomCrop(size) train_flip = transforms.RandomHorizontalFlip(p=1.0) train_transforms = transforms.Compose( [ @@ -1075,11 +1074,11 @@ class DreamBoothDataset(Dataset): # flip image = train_flip(image) if args.center_crop: - y1 = max(0, int(round((image.height - args.resolution) / 2.0))) - x1 = max(0, int(round((image.width - args.resolution) / 2.0))) + y1 = max(0, int(round((image.height - self.size) / 2.0))) + x1 = max(0, int(round((image.width - self.size) / 2.0))) image = train_crop(image) else: - y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution)) + y1, x1, h, w = train_crop.get_params(image, (self.size, self.size)) image = crop(image, y1, x1, h, w) image = train_transforms(image) self.pixel_values.append(image) @@ -1102,7 +1101,7 @@ class DreamBoothDataset(Dataset): self.image_transforms = transforms.Compose( [ transforms.Resize(size, interpolation=interpolation), - transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.CenterCrop(size) if args.center_crop else transforms.RandomCrop(size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ] @@ -1827,6 +1826,7 @@ def main(args): # Dataset and DataLoaders creation: train_dataset = DreamBoothDataset( + args=args, instance_data_root=args.instance_data_dir, instance_prompt=args.instance_prompt, train_text_encoder_ti=args.train_text_encoder_ti, @@ -1836,7 +1836,6 @@ def main(args): class_num=args.num_class_images, size=args.resolution, repeats=args.repeats, - center_crop=args.center_crop, ) train_dataloader = torch.utils.data.DataLoader( diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 30d497892f..80c78b8a96 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -366,6 +366,8 @@ else: [ "StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline", + "WanAutoBlocks", + "WanModularPipeline", ] ) _import_structure["pipelines"].extend( @@ -999,6 +1001,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: from .modular_pipelines import ( StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline, + WanAutoBlocks, + WanModularPipeline, ) from .pipelines import ( AllegroPipeline, diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py index 960d14e6fa..5fa047257f 100644 --- a/src/diffusers/hooks/_helpers.py +++ b/src/diffusers/hooks/_helpers.py @@ -107,6 +107,7 @@ class TransformerBlockRegistry: def 
_register_attention_processors_metadata(): from ..models.attention_processor import AttnProcessor2_0 from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor + from ..models.transformers.transformer_wan import WanAttnProcessor2_0 # AttnProcessor2_0 AttentionProcessorRegistry.register( @@ -124,6 +125,14 @@ def _register_attention_processors_metadata(): ), ) + # WanAttnProcessor2_0 + AttentionProcessorRegistry.register( + model_class=WanAttnProcessor2_0, + metadata=AttentionProcessorMetadata( + skip_processor_output_fn=_skip_proc_output_fn_Attention_WanAttnProcessor2_0, + ), + ) + def _register_transformer_blocks_metadata(): from ..models.attention import BasicTransformerBlock @@ -261,4 +270,5 @@ def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, * _skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states _skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states +_skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hidden_states # fmt: on diff --git a/src/diffusers/hooks/layer_skip.py b/src/diffusers/hooks/layer_skip.py index 14e6c2f888..0ce02e987d 100644 --- a/src/diffusers/hooks/layer_skip.py +++ b/src/diffusers/hooks/layer_skip.py @@ -91,10 +91,19 @@ class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode): if kwargs is None: kwargs = {} if func is torch.nn.functional.scaled_dot_product_attention: + query = kwargs.get("query", None) + key = kwargs.get("key", None) value = kwargs.get("value", None) - if value is None: - value = args[2] - return value + query = query if query is not None else args[0] + key = key if key is not None else args[1] + value = value if value is not None else args[2] + # If the Q sequence length does not match KV sequence length, methods like + # Perturbed Attention Guidance cannot be used (because the caller expects + # the same sequence length as Q, but if we return V here, it will not match). + # When Q.shape[2] != V.shape[2], PAG will essentially not be applied and + # the overall effect would that be of normal CFG with a scale of (guidance_scale + perturbed_guidance_scale). 
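+            # For example, in self-attention Q/K/V all have shape (B, H, S, D), so the check below
+            # passes and V is returned directly (the attention scores are skipped entirely). For
+            # cross-attention, where Q has S_q tokens but K/V have S_kv tokens, we fall through to
+            # the real scaled_dot_product_attention call instead of returning a mismatched V.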
+ if query.shape[2] == value.shape[2]: + return value return func(*args, **kwargs) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 141a7fee85..c00ec7dd6e 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -38,18 +38,29 @@ from ..utils import ( from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS -logger = get_logger(__name__) # pylint: disable=invalid-name +_REQUIRED_FLASH_VERSION = "2.6.3" +_REQUIRED_SAGE_VERSION = "2.1.1" +_REQUIRED_FLEX_VERSION = "2.5.0" +_REQUIRED_XLA_VERSION = "2.2" +_REQUIRED_XFORMERS_VERSION = "0.0.29" + +_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION) +_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available() +_CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION) +_CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION) +_CAN_USE_NPU_ATTN = is_torch_npu_available() +_CAN_USE_XLA_ATTN = is_torch_xla_available() and is_torch_xla_version(">=", _REQUIRED_XLA_VERSION) +_CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _REQUIRED_XFORMERS_VERSION) -if is_flash_attn_available() and is_flash_attn_version(">=", "2.6.3"): +if _CAN_USE_FLASH_ATTN: from flash_attn import flash_attn_func, flash_attn_varlen_func else: - logger.warning("`flash-attn` is not available or the version is too old. Please install `flash-attn>=2.6.3`.") flash_attn_func = None flash_attn_varlen_func = None -if is_flash_attn_3_available(): +if _CAN_USE_FLASH_ATTN_3: from flash_attn_interface import flash_attn_func as flash_attn_3_func from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func else: @@ -57,7 +68,7 @@ else: flash_attn_3_varlen_func = None -if is_sageattention_available() and is_sageattention_version(">=", "2.1.1"): +if _CAN_USE_SAGE_ATTN: from sageattention import ( sageattn, sageattn_qk_int8_pv_fp8_cuda, @@ -67,9 +78,6 @@ if is_sageattention_available() and is_sageattention_version(">=", "2.1.1"): sageattn_varlen, ) else: - logger.warning( - "`sageattention` is not available or the version is too old. Please install `sageattention>=2.1.1`." - ) sageattn = None sageattn_qk_int8_pv_fp16_cuda = None sageattn_qk_int8_pv_fp16_triton = None @@ -78,39 +86,39 @@ else: sageattn_varlen = None -if is_torch_version(">=", "2.5.0"): +if _CAN_USE_FLEX_ATTN: # We cannot import the flex_attention function from the package directly because it is expected (from the # pytorch documentation) that the user may compile it. If we import directly, we will not have access to the # compiled function. import torch.nn.attention.flex_attention as flex_attention -if is_torch_npu_available(): +if _CAN_USE_NPU_ATTN: from torch_npu import npu_fusion_attention else: npu_fusion_attention = None -if is_torch_xla_available() and is_torch_xla_version(">", "2.2"): +if _CAN_USE_XLA_ATTN: from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention else: xla_flash_attention = None -if is_xformers_available() and is_xformers_version(">=", "0.0.29"): +if _CAN_USE_XFORMERS_ATTN: import xformers.ops as xops else: - logger.warning("`xformers` is not available or the version is too old. 
Please install `xformers>=0.0.29`.") xops = None +logger = get_logger(__name__) # pylint: disable=invalid-name + # TODO(aryan): Add support for the following: # - Sage Attention++ # - block sparse, radial and other attention methods # - CP with sage attention, flex, xformers, other missing backends # - Add support for normal and CP training with backends that don't support it yet - _SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"] _SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"] _SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"] @@ -179,13 +187,16 @@ class _AttentionBackendRegistry: @contextlib.contextmanager -def attention_backend(backend: AttentionBackendName = AttentionBackendName.NATIVE): +def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE): """ Context manager to set the active attention backend. """ if backend not in _AttentionBackendRegistry._backends: raise ValueError(f"Backend {backend} is not registered.") + backend = AttentionBackendName(backend) + _check_attention_backend_requirements(backend) + old_backend = _AttentionBackendRegistry._active_backend _AttentionBackendRegistry._active_backend = backend @@ -226,9 +237,10 @@ def dispatch_attention_fn( "dropout_p": dropout_p, "is_causal": is_causal, "scale": scale, - "enable_gqa": enable_gqa, **attention_kwargs, } + if is_torch_version(">=", "2.5.0"): + kwargs["enable_gqa"] = enable_gqa if _AttentionBackendRegistry._checks_enabled: removed_kwargs = set(kwargs) - set(_AttentionBackendRegistry._supported_arg_names[backend_name]) @@ -305,6 +317,57 @@ def _check_shape( # ===== Helper functions ===== +def _check_attention_backend_requirements(backend: AttentionBackendName) -> None: + if backend in [AttentionBackendName.FLASH, AttentionBackendName.FLASH_VARLEN]: + if not _CAN_USE_FLASH_ATTN: + raise RuntimeError( + f"Flash Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `flash-attn>={_REQUIRED_FLASH_VERSION}`." + ) + + elif backend in [AttentionBackendName._FLASH_3, AttentionBackendName._FLASH_VARLEN_3]: + if not _CAN_USE_FLASH_ATTN_3: + raise RuntimeError( + f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source." + ) + + elif backend in [ + AttentionBackendName.SAGE, + AttentionBackendName.SAGE_VARLEN, + AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA, + AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90, + AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA, + AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON, + ]: + if not _CAN_USE_SAGE_ATTN: + raise RuntimeError( + f"Sage Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `sageattention>={_REQUIRED_SAGE_VERSION}`." + ) + + elif backend == AttentionBackendName.FLEX: + if not _CAN_USE_FLEX_ATTN: + raise RuntimeError( + f"Flex Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch>=2.5.0`." + ) + + elif backend == AttentionBackendName._NATIVE_NPU: + if not _CAN_USE_NPU_ATTN: + raise RuntimeError( + f"NPU Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch_npu`." 
+ ) + + elif backend == AttentionBackendName._NATIVE_XLA: + if not _CAN_USE_XLA_ATTN: + raise RuntimeError( + f"XLA Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch_xla>={_REQUIRED_XLA_VERSION}`." + ) + + elif backend == AttentionBackendName.XFORMERS: + if not _CAN_USE_XFORMERS_ATTN: + raise RuntimeError( + f"Xformers Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `xformers>={_REQUIRED_XFORMERS_VERSION}`." + ) + + @functools.lru_cache(maxsize=128) def _prepare_for_flash_attn_or_sage_varlen_without_mask( batch_size: int, diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index fb01e7e01a..4941b6d2a7 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -622,19 +622,21 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): attention as backend. """ from .attention import AttentionModuleMixin - from .attention_dispatch import AttentionBackendName + from .attention_dispatch import AttentionBackendName, _check_attention_backend_requirements # TODO: the following will not be required when everything is refactored to AttentionModuleMixin from .attention_processor import Attention, MochiAttention + logger.warning("Attention backends are an experimental feature and the API may be subject to change.") + backend = backend.lower() available_backends = {x.value for x in AttentionBackendName.__members__.values()} if backend not in available_backends: raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends)) - backend = AttentionBackendName(backend) - attention_classes = (Attention, MochiAttention, AttentionModuleMixin) + _check_attention_backend_requirements(backend) + attention_classes = (Attention, MochiAttention, AttentionModuleMixin) for module in self.modules(): if not isinstance(module, attention_classes): continue @@ -651,6 +653,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): from .attention import AttentionModuleMixin from .attention_processor import Attention, MochiAttention + logger.warning("Attention backends are an experimental feature and the API may be subject to change.") + attention_classes = (Attention, MochiAttention, AttentionModuleMixin) for module in self.modules(): if not isinstance(module, attention_classes): diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 0f789d3961..736deb28c3 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -165,7 +165,7 @@ class UNet2DConditionModel( """ _supports_gradient_checkpointing = True - _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"] + _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"] _skip_layerwise_casting_patterns = ["norm"] _repeated_blocks = ["BasicTransformerBlock"] diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index bf34eed28b..b3025bf4d3 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -40,6 +40,7 @@ else: "InsertableDict", ] _import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"] + _import_structure["wan"] = ["WanAutoBlocks", "WanModularPipeline"] 
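A rough usage sketch of the eager backend checks introduced in attention_dispatch.py above; the checkpoint id, the prompt, and the assumption that this pipeline routes its attention through the dispatcher are illustrative, not part of this patch:

import torch
from diffusers import WanPipeline
from diffusers.models.attention_dispatch import attention_backend

pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16).to("cuda")

# The plain string is now converted to AttentionBackendName.FLASH and its requirements are
# checked eagerly: if flash-attn>=2.6.3 is not installed, this raises a RuntimeError here
# rather than failing deep inside the attention call.
with attention_backend("flash"):
    video = pipe("a cat surfing a wave", num_frames=33).frames[0]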
_import_structure["components_manager"] = ["ComponentsManager"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -71,6 +72,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline, ) + from .wan import WanAutoBlocks, WanModularPipeline else: import sys diff --git a/src/diffusers/modular_pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py index 08e6d80fef..f48a227e2e 100644 --- a/src/diffusers/modular_pipelines/components_manager.py +++ b/src/diffusers/modular_pipelines/components_manager.py @@ -386,6 +386,7 @@ class ComponentsManager: id(component) is Python's built-in unique identifier for the object """ component_id = f"{name}_{id(component)}" + is_new_component = True # check for duplicated components for comp_id, comp in self.components.items(): @@ -394,6 +395,7 @@ class ComponentsManager: if comp_name == name: logger.warning(f"ComponentsManager: component '{name}' already exists as '{comp_id}'") component_id = comp_id + is_new_component = False break else: logger.warning( @@ -426,7 +428,9 @@ class ComponentsManager: logger.warning( f"ComponentsManager: removing existing {name} from collection '{collection}': {comp_id}" ) - self.remove(comp_id) + # remove existing component from this collection (if it is not in any other collection, will be removed from ComponentsManager) + self.remove_from_collection(comp_id, collection) + self.collections[collection].add(component_id) logger.info( f"ComponentsManager: added component '{name}' in collection '{collection}': {component_id}" @@ -434,11 +438,29 @@ class ComponentsManager: else: logger.info(f"ComponentsManager: added component '{name}' as '{component_id}'") - if self._auto_offload_enabled: + if self._auto_offload_enabled and is_new_component: self.enable_auto_cpu_offload(self._auto_offload_device) return component_id + def remove_from_collection(self, component_id: str, collection: str): + """ + Remove a component from a collection. + """ + if collection not in self.collections: + logger.warning(f"Collection '{collection}' not found in ComponentsManager") + return + if component_id not in self.collections[collection]: + logger.warning(f"Component '{component_id}' not found in collection '{collection}'") + return + # remove from the collection + self.collections[collection].remove(component_id) + # check if this component is in any other collection + comp_colls = [coll for coll, comps in self.collections.items() if component_id in comps] + if not comp_colls: # only if no other collection contains this component, remove it + logger.warning(f"ComponentsManager: removing component '{component_id}' from ComponentsManager") + self.remove(component_id) + def remove(self, component_id: str = None): """ Remove a component from the ComponentsManager. 
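A small sketch of the new collection semantics in ComponentsManager; the checkpoint id and the `add(name, component, collection=...)` signature are assumptions inferred from the surrounding code, not part of this patch:

from diffusers import AutoencoderKLWan
from diffusers.modular_pipelines import ComponentsManager

manager = ComponentsManager()
vae = AutoencoderKLWan.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae")

# Re-adding the same object under the same name reuses the existing component id, and
# auto CPU offload is only re-enabled for genuinely new components.
vae_id = manager.add("vae", vae, collection="wan")
assert manager.add("vae", vae, collection="wan") == vae_id

# Removing a component from a collection only evicts it from the manager once no other
# collection still references it.
manager.remove_from_collection(vae_id, "wan")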
diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 6056623d7f..ef2bfb494b 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -60,12 +60,14 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name MODULAR_PIPELINE_MAPPING = OrderedDict( [ ("stable-diffusion-xl", "StableDiffusionXLModularPipeline"), + ("wan", "WanModularPipeline"), ] ) MODULAR_PIPELINE_BLOCKS_MAPPING = OrderedDict( [ ("StableDiffusionXLModularPipeline", "StableDiffusionXLAutoBlocks"), + ("WanModularPipeline", "WanAutoBlocks"), ] ) diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py index b63925df26..f2fc015e94 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline_utils.py +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -185,6 +185,8 @@ class ComponentSpec: Unique identifier for this spec's pretrained load, composed of repo|subfolder|variant|revision (no empty segments). """ + if self.default_creation_method == "from_config": + return "null" parts = [getattr(self, k) for k in self.loading_fields()] parts = ["null" if p is None else p for p in parts] return "|".join(p for p in parts if p) diff --git a/src/diffusers/modular_pipelines/wan/__init__.py b/src/diffusers/modular_pipelines/wan/__init__.py new file mode 100644 index 0000000000..7b548e003c --- /dev/null +++ b/src/diffusers/modular_pipelines/wan/__init__.py @@ -0,0 +1,66 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["encoders"] = ["WanTextEncoderStep"] + _import_structure["modular_blocks"] = [ + "ALL_BLOCKS", + "AUTO_BLOCKS", + "TEXT2VIDEO_BLOCKS", + "WanAutoBeforeDenoiseStep", + "WanAutoBlocks", + "WanAutoBlocks", + "WanAutoDecodeStep", + "WanAutoDenoiseStep", + ] + _import_structure["modular_pipeline"] = ["WanModularPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .encoders import WanTextEncoderStep + from .modular_blocks import ( + ALL_BLOCKS, + AUTO_BLOCKS, + TEXT2VIDEO_BLOCKS, + WanAutoBeforeDenoiseStep, + WanAutoBlocks, + WanAutoDecodeStep, + WanAutoDenoiseStep, + ) + from .modular_pipeline import WanModularPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/modular_pipelines/wan/before_denoise.py b/src/diffusers/modular_pipelines/wan/before_denoise.py new file mode 100644 index 0000000000..ef65b64537 --- /dev/null +++ b/src/diffusers/modular_pipelines/wan/before_denoise.py @@ -0,0 
+1,365 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import List, Optional, Union + +import torch + +from ...schedulers import UniPCMultistepScheduler +from ...utils import logging +from ...utils.torch_utils import randn_tensor +from ..modular_pipeline import PipelineBlock, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import WanModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# TODO(yiyi, aryan): We need another step before text encoder to set the `num_inference_steps` attribute for guider so that +# things like when to do guidance and how many conditions to be prepared can be determined. Currently, this is done by +# always assuming you want to do guidance in the Guiders. So, negative embeddings are prepared regardless of what the +# configuration of guider is. + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class WanInputStep(PipelineBlock): + model_name = "wan" + + @property + def description(self) -> str: + return ( + "Input processing step that:\n" + " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" + " 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_videos_per_prompt`\n\n" + "All input tensors are expected to have either batch_size=1 or match the batch_size\n" + "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n" + "have a final batch_size of batch_size * num_videos_per_prompt." + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("num_videos_per_prompt", default=1), + ] + + @property + def intermediate_inputs(self) -> List[str]: + return [ + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Pre-generated text embeddings. Can be generated from text_encoder step.", + ), + InputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + description="Pre-generated negative text embeddings. Can be generated from text_encoder step.", + ), + ] + + @property + def intermediate_outputs(self) -> List[str]: + return [ + OutputParam( + "batch_size", + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_videos_per_prompt", + ), + OutputParam( + "dtype", + type_hint=torch.dtype, + description="Data type of model tensor inputs (determined by `prompt_embeds`)", + ), + OutputParam( + "prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields + description="text embeddings used to guide the image generation", + ), + OutputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields + description="negative text embeddings used to guide the image generation", + ), + ] + + def check_inputs(self, components, block_state): + if block_state.prompt_embeds is not None and block_state.negative_prompt_embeds is not None: + if block_state.prompt_embeds.shape != block_state.negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `negative_prompt_embeds`" + f" {block_state.negative_prompt_embeds.shape}." 
+ ) + + @torch.no_grad() + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + block_state.batch_size = block_state.prompt_embeds.shape[0] + block_state.dtype = block_state.prompt_embeds.dtype + + _, seq_len, _ = block_state.prompt_embeds.shape + block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_videos_per_prompt, 1) + block_state.prompt_embeds = block_state.prompt_embeds.view( + block_state.batch_size * block_state.num_videos_per_prompt, seq_len, -1 + ) + + if block_state.negative_prompt_embeds is not None: + _, seq_len, _ = block_state.negative_prompt_embeds.shape + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat( + 1, block_state.num_videos_per_prompt, 1 + ) + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view( + block_state.batch_size * block_state.num_videos_per_prompt, seq_len, -1 + ) + + self.set_block_state(state, block_state) + + return components, state + + +class WanSetTimestepsStep(PipelineBlock): + model_name = "wan" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", UniPCMultistepScheduler), + ] + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for inference" + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("num_inference_steps", default=50), + InputParam("timesteps"), + InputParam("sigmas"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"), + OutputParam( + "num_inference_steps", + type_hint=int, + description="The number of denoising steps to perform at inference time", + ), + ] + + @torch.no_grad() + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.device = components._execution_device + + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + components.scheduler, + block_state.num_inference_steps, + block_state.device, + block_state.timesteps, + block_state.sigmas, + ) + + self.set_block_state(state, block_state) + return components, state + + +class WanPrepareLatentsStep(PipelineBlock): + model_name = "wan" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [] + + @property + def description(self) -> str: + return "Prepare latents step that prepares the latents for the text-to-video generation process" + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("height", type_hint=int), + InputParam("width", type_hint=int), + InputParam("num_frames", type_hint=int), + InputParam("latents", type_hint=Optional[torch.Tensor]), + InputParam("num_videos_per_prompt", type_hint=int, default=1), + ] + + @property + def intermediate_inputs(self) -> List[InputParam]: + return [ + InputParam("generator"), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be `batch_size * num_videos_per_prompt`. 
Can be generated in input step.", + ), + InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process" + ) + ] + + @staticmethod + def check_inputs(components, block_state): + if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or ( + block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}." + ) + if block_state.num_frames is not None and ( + block_state.num_frames < 1 or (block_state.num_frames - 1) % components.vae_scale_factor_temporal != 0 + ): + raise ValueError( + f"`num_frames` has to be greater than 0, and (num_frames - 1) must be divisible by {components.vae_scale_factor_temporal}, but got {block_state.num_frames}." + ) + + @staticmethod + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.prepare_latents with self->comp + def prepare_latents( + comp, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + num_latent_frames = (num_frames - 1) // comp.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + int(height) // comp.vae_scale_factor_spatial, + int(width) // comp.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + @torch.no_grad() + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.height = block_state.height or components.default_height + block_state.width = block_state.width or components.default_width + block_state.num_frames = block_state.num_frames or components.default_num_frames + block_state.device = components._execution_device + block_state.dtype = torch.float32 # Wan latents should be torch.float32 for best quality + block_state.num_channels_latents = components.num_channels_latents + + self.check_inputs(components, block_state) + + block_state.latents = self.prepare_latents( + components, + block_state.batch_size * block_state.num_videos_per_prompt, + block_state.num_channels_latents, + block_state.height, + block_state.width, + block_state.num_frames, + block_state.dtype, + block_state.device, + block_state.generator, + block_state.latents, + ) + + self.set_block_state(state, block_state) + + return components, state diff --git a/src/diffusers/modular_pipelines/wan/decoders.py b/src/diffusers/modular_pipelines/wan/decoders.py new file mode 100644 index 0000000000..4fadeed4b9 --- /dev/null +++ b/src/diffusers/modular_pipelines/wan/decoders.py @@ -0,0 +1,105 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
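A quick sanity check of the latent shape produced by `WanPrepareLatentsStep` above, assuming the usual Wan VAE factors (8x spatial, 4x temporal) and 16 latent channels; these numbers are illustrative, not read from this patch:

batch_size, num_channels_latents = 1, 16
height, width, num_frames = 480, 832, 81
vae_scale_factor_spatial, vae_scale_factor_temporal = 8, 4

num_latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1  # 21
shape = (
    batch_size,
    num_channels_latents,
    num_latent_frames,
    height // vae_scale_factor_spatial,  # 60
    width // vae_scale_factor_spatial,  # 104
)
print(shape)  # (1, 16, 21, 60, 104)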
+ +from typing import Any, List, Tuple, Union + +import numpy as np +import PIL +import torch + +from ...configuration_utils import FrozenDict +from ...models import AutoencoderKLWan +from ...utils import logging +from ...video_processor import VideoProcessor +from ..modular_pipeline import PipelineBlock, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class WanDecodeStep(PipelineBlock): + model_name = "wan" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKLWan), + ComponentSpec( + "video_processor", + VideoProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config", + ), + ] + + @property + def description(self) -> str: + return "Step that decodes the denoised latents into images" + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("output_type", default="pil"), + ] + + @property + def intermediate_inputs(self) -> List[str]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The denoised latents from the denoising step", + ) + ] + + @property + def intermediate_outputs(self) -> List[str]: + return [ + OutputParam( + "videos", + type_hint=Union[List[List[PIL.Image.Image]], List[torch.Tensor], List[np.ndarray]], + description="The generated videos, can be a PIL.Image.Image, torch.Tensor or a numpy array", + ) + ] + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + vae_dtype = components.vae.dtype + + if not block_state.output_type == "latent": + latents = block_state.latents + latents_mean = ( + torch.tensor(components.vae.config.latents_mean) + .view(1, components.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view( + 1, components.vae.config.z_dim, 1, 1, 1 + ).to(latents.device, latents.dtype) + latents = latents / latents_std + latents_mean + latents = latents.to(vae_dtype) + block_state.videos = components.vae.decode(latents, return_dict=False)[0] + else: + block_state.videos = block_state.latents + + block_state.videos = components.video_processor.postprocess_video( + block_state.videos, output_type=block_state.output_type + ) + + self.set_block_state(state, block_state) + + return components, state diff --git a/src/diffusers/modular_pipelines/wan/denoise.py b/src/diffusers/modular_pipelines/wan/denoise.py new file mode 100644 index 0000000000..76c5cda5f9 --- /dev/null +++ b/src/diffusers/modular_pipelines/wan/denoise.py @@ -0,0 +1,261 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
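The de-normalization in `WanDecodeStep` above divides by a pre-inverted `latents_std`, which is the same as multiplying by the true standard deviation; a tiny self-contained check with random stand-in statistics (not the real `AutoencoderKLWan` values):

import torch

z_dim = 16
latents = torch.randn(1, z_dim, 21, 60, 104)
mean = torch.randn(z_dim).view(1, z_dim, 1, 1, 1)
std = torch.rand(z_dim).view(1, z_dim, 1, 1, 1) + 0.5  # stand-in per-channel statistics

inv_std = 1.0 / std
# Dividing by the inverted std and adding the mean is exactly latents * std + mean.
assert torch.allclose(latents / inv_std + mean, latents * std + mean)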
+ +from typing import Any, List, Tuple + +import torch + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...models import WanTransformer3DModel +from ...schedulers import UniPCMultistepScheduler +from ...utils import logging +from ..modular_pipeline import ( + BlockState, + LoopSequentialPipelineBlocks, + PipelineBlock, + PipelineState, +) +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import WanModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class WanLoopDenoiser(PipelineBlock): + model_name = "wan" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 5.0}), + default_creation_method="from_config", + ), + ComponentSpec("transformer", WanTransformer3DModel), + ] + + @property + def description(self) -> str: + return ( + "Step within the denoising loop that denoise the latents with guidance. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `WanDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [ + InputParam("attention_kwargs"), + ] + + @property + def intermediate_inputs(self) -> List[str]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", + ), + InputParam( + kwargs_type="guider_input_fields", + description=( + "All conditional model inputs that need to be prepared with guider. " + "It should contain prompt_embeds/negative_prompt_embeds. " + "Please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state" + ), + ), + ] + + @torch.no_grad() + def __call__( + self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds) + # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds) + guider_input_fields = { + "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"), + } + transformer_dtype = components.transformer.dtype + + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + + # Prepare mini‐batches according to guidance method and `guider_input_fields` + # Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds. + # e.g. 
for CFG, we prepare two batches: one for uncond, one for cond + # for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds + # for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds + guider_state = components.guider.prepare_inputs(block_state, guider_input_fields) + + # run the denoiser for each guidance batch + for guider_state_batch in guider_state: + components.guider.prepare_models(components.transformer) + cond_kwargs = guider_state_batch.as_dict() + cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields} + prompt_embeds = cond_kwargs.pop("prompt_embeds") + + # Predict the noise residual + # store the noise_pred in guider_state_batch so that we can apply guidance across all batches + guider_state_batch.noise_pred = components.transformer( + hidden_states=block_state.latents.to(transformer_dtype), + timestep=t.flatten(), + encoder_hidden_states=prompt_embeds, + attention_kwargs=block_state.attention_kwargs, + return_dict=False, + )[0] + components.guider.cleanup_models(components.transformer) + + # Perform guidance + block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state) + + return components, block_state + + +class WanLoopAfterDenoiser(PipelineBlock): + model_name = "wan" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", UniPCMultistepScheduler), + ] + + @property + def description(self) -> str: + return ( + "step within the denoising loop that update the latents. " + "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` " + "object (e.g. `WanDenoiseLoopWrapper`)" + ) + + @property + def inputs(self) -> List[Tuple[str, Any]]: + return [] + + @property + def intermediate_inputs(self) -> List[str]: + return [ + InputParam("generator"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")] + + @torch.no_grad() + def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + # Perform scheduler step using the predicted output + latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step( + block_state.noise_pred.float(), + t, + block_state.latents.float(), + **block_state.scheduler_step_kwargs, + return_dict=False, + )[0] + + if block_state.latents.dtype != latents_dtype: + block_state.latents = block_state.latents.to(latents_dtype) + + return components, block_state + + +class WanDenoiseLoopWrapper(LoopSequentialPipelineBlocks): + model_name = "wan" + + @property + def description(self) -> str: + return ( + "Pipeline block that iteratively denoise the latents over `timesteps`. " + "The specific steps with each iteration can be customized with `sub_blocks` attributes" + ) + + @property + def loop_expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 5.0}), + default_creation_method="from_config", + ), + ComponentSpec("scheduler", UniPCMultistepScheduler), + ComponentSpec("transformer", WanTransformer3DModel), + ] + + @property + def loop_intermediate_inputs(self) -> List[InputParam]: + return [ + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. 
Can be generated in set_timesteps step.",
+            ),
+            InputParam(
+                "num_inference_steps",
+                required=True,
+                type_hint=int,
+                description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
+            ),
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+
+        block_state.num_warmup_steps = max(
+            len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
+        )
+
+        with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
+            for i, t in enumerate(block_state.timesteps):
+                components, block_state = self.loop_step(components, block_state, i=i, t=t)
+                if i == len(block_state.timesteps) - 1 or (
+                    (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0
+                ):
+                    progress_bar.update()
+
+        self.set_block_state(state, block_state)
+
+        return components, state
+
+
+class WanDenoiseStep(WanDenoiseLoopWrapper):
+    block_classes = [
+        WanLoopDenoiser,
+        WanLoopAfterDenoiser,
+    ]
+    block_names = ["denoiser", "after_denoiser"]
+
+    @property
+    def description(self) -> str:
+        return (
+            "Denoise step that iteratively denoises the latents. \n"
+            "Its loop logic is defined in the `WanDenoiseLoopWrapper.__call__` method. \n"
+            "At each iteration, it runs the blocks defined in `sub_blocks` sequentially:\n"
+            " - `WanLoopDenoiser`\n"
+            " - `WanLoopAfterDenoiser`\n"
+            "This block supports the text2vid task."
+        )
diff --git a/src/diffusers/modular_pipelines/wan/encoders.py b/src/diffusers/modular_pipelines/wan/encoders.py
new file mode 100644
index 0000000000..b2ecfd1aa6
--- /dev/null
+++ b/src/diffusers/modular_pipelines/wan/encoders.py
@@ -0,0 +1,242 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
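Conceptually, the loop wrapper above chains its sub-blocks once per timestep, roughly as in this toy sketch; the body of `loop_step` is an assumption distilled from the block descriptions, not the actual `LoopSequentialPipelineBlocks` implementation:

class ToyLoopWrapper:
    # Stand-ins for WanLoopDenoiser / WanLoopAfterDenoiser are any callables with the
    # signature block(components, block_state, i=..., t=...).
    def __init__(self, sub_blocks):
        self.sub_blocks = sub_blocks

    def loop_step(self, components, block_state, i, t):
        for block in self.sub_blocks:
            components, block_state = block(components, block_state, i=i, t=t)
        return components, block_state

    def __call__(self, components, block_state):
        for i, t in enumerate(block_state.timesteps):
            components, block_state = self.loop_step(components, block_state, i=i, t=t)
        return components, block_state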
+ +import html +from typing import List, Optional, Union + +import regex as re +import torch +from transformers import AutoTokenizer, UMT5EncoderModel + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...utils import is_ftfy_available, logging +from ..modular_pipeline import PipelineBlock, PipelineState +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from .modular_pipeline import WanModularPipeline + + +if is_ftfy_available(): + import ftfy + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +class WanTextEncoderStep(PipelineBlock): + model_name = "wan" + + @property + def description(self) -> str: + return "Text Encoder step that generate text_embeddings to guide the video generation" + + @property + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("text_encoder", UMT5EncoderModel), + ComponentSpec("tokenizer", AutoTokenizer), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 5.0}), + default_creation_method="from_config", + ), + ] + + @property + def expected_configs(self) -> List[ConfigSpec]: + return [] + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("prompt"), + InputParam("negative_prompt"), + InputParam("attention_kwargs"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + "prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="guider_input_fields", + description="text embeddings used to guide the image generation", + ), + OutputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="guider_input_fields", + description="negative text embeddings used to guide the image generation", + ), + ] + + @staticmethod + def check_inputs(block_state): + if block_state.prompt is not None and ( + not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list) + ): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}") + + @staticmethod + def _get_t5_prompt_embeds( + components, + prompt: Union[str, List[str]], + max_sequence_length: int, + device: torch.device, + ): + dtype = components.text_encoder.dtype + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + + text_inputs = components.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + prompt_embeds = components.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + return prompt_embeds + + @staticmethod + def encode_prompt( + components, + prompt: str, + device: Optional[torch.device] = None, + 
num_videos_per_prompt: int = 1, + prepare_unconditional_embeds: bool = True, + negative_prompt: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_videos_per_prompt (`int`): + number of videos that should be generated per prompt + prepare_unconditional_embeds (`bool`): + whether to use prepare unconditional embeddings or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + max_sequence_length (`int`, defaults to `512`): + The maximum number of text tokens to be used for the generation process. + """ + device = device or components._execution_device + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = WanTextEncoderStep._get_t5_prompt_embeds(components, prompt, max_sequence_length, device) + + if prepare_unconditional_embeds and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + + negative_prompt_embeds = WanTextEncoderStep._get_t5_prompt_embeds( + components, negative_prompt, max_sequence_length, device + ) + + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + if prepare_unconditional_embeds: + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds + + @torch.no_grad() + def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState: + # Get inputs and intermediates + block_state = self.get_block_state(state) + self.check_inputs(block_state) + + block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1 + block_state.device = components._execution_device + + # Encode input prompt + ( + block_state.prompt_embeds, + block_state.negative_prompt_embeds, + ) = self.encode_prompt( + components, + block_state.prompt, + block_state.device, + 1, + block_state.prepare_unconditional_embeds, + block_state.negative_prompt, + prompt_embeds=None, + negative_prompt_embeds=None, + ) + + # Add outputs + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/wan/modular_blocks.py b/src/diffusers/modular_pipelines/wan/modular_blocks.py new file mode 100644 index 0000000000..5f4c1a9835 --- /dev/null +++ b/src/diffusers/modular_pipelines/wan/modular_blocks.py @@ -0,0 +1,144 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
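A toy illustration of the embedding post-processing in `_get_t5_prompt_embeds` above: each sequence is trimmed to its true length (from the attention mask) and zero-padded back to `max_sequence_length`; the shapes here are made up for brevity:

import torch

max_sequence_length = 8
prompt_embeds = torch.randn(2, max_sequence_length, 4)  # (batch, seq, dim) stand-in
seq_lens = torch.tensor([3, 5])  # true token counts per prompt

trimmed = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
padded = torch.stack(
    [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in trimmed], dim=0
)
assert padded.shape == (2, max_sequence_length, 4)
assert torch.all(padded[0, 3:] == 0) and torch.all(padded[1, 5:] == 0)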
+ +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import InsertableDict +from .before_denoise import ( + WanInputStep, + WanPrepareLatentsStep, + WanSetTimestepsStep, +) +from .decoders import WanDecodeStep +from .denoise import WanDenoiseStep +from .encoders import WanTextEncoderStep + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# before_denoise: text2vid +class WanBeforeDenoiseStep(SequentialPipelineBlocks): + block_classes = [ + WanInputStep, + WanSetTimestepsStep, + WanPrepareLatentsStep, + ] + block_names = ["input", "set_timesteps", "prepare_latents"] + + @property + def description(self): + return ( + "Before denoise step that prepares the inputs for the denoise step.\n" + + "This is a sequential pipeline block:\n" + + " - `WanInputStep` is used to adjust the batch size of the model inputs\n" + + " - `WanSetTimestepsStep` is used to set the timesteps\n" + + " - `WanPrepareLatentsStep` is used to prepare the latents\n" + ) + + +# before_denoise: all tasks (text2vid,) +class WanAutoBeforeDenoiseStep(AutoPipelineBlocks): + block_classes = [ + WanBeforeDenoiseStep, + ] + block_names = ["text2vid"] + block_trigger_inputs = [None] + + @property + def description(self): + return ( + "Before denoise step that prepares the inputs for the denoise step.\n" + + "This is an auto pipeline block that works for text2vid.\n" + + " - `WanBeforeDenoiseStep` (text2vid) is used.\n" + ) + + +# denoise: text2vid +class WanAutoDenoiseStep(AutoPipelineBlocks): + block_classes = [ + WanDenoiseStep, + ] + block_names = ["denoise"] + block_trigger_inputs = [None] + + @property + def description(self) -> str: + return ( + "Denoise step that iteratively denoises the latents. " + "This is an auto pipeline block that works for text2vid tasks." + " - `WanDenoiseStep` (denoise) for text2vid tasks."
+ ) + + +# decode: all tasks (text2vid) +class WanAutoDecodeStep(AutoPipelineBlocks): + block_classes = [WanDecodeStep] + block_names = ["non-inpaint"] + block_trigger_inputs = [None] + + @property + def description(self): + return "Decode step that decodes the denoised latents into video outputs.\n - `WanDecodeStep`" + + +# text2vid +class WanAutoBlocks(SequentialPipelineBlocks): + block_classes = [ + WanTextEncoderStep, + WanAutoBeforeDenoiseStep, + WanAutoDenoiseStep, + WanAutoDecodeStep, + ] + block_names = [ + "text_encoder", + "before_denoise", + "denoise", + "decoder", + ] + + @property + def description(self): + return ( + "Auto Modular pipeline for text-to-video using Wan.\n" + + "- for text-to-video generation, all you need to provide is `prompt`" + ) + + +TEXT2VIDEO_BLOCKS = InsertableDict( + [ + ("text_encoder", WanTextEncoderStep), + ("input", WanInputStep), + ("set_timesteps", WanSetTimestepsStep), + ("prepare_latents", WanPrepareLatentsStep), + ("denoise", WanDenoiseStep), + ("decode", WanDecodeStep), + ] +) + + +AUTO_BLOCKS = InsertableDict( + [ + ("text_encoder", WanTextEncoderStep), + ("before_denoise", WanAutoBeforeDenoiseStep), + ("denoise", WanAutoDenoiseStep), + ("decode", WanAutoDecodeStep), + ] +) + + +ALL_BLOCKS = { + "text2video": TEXT2VIDEO_BLOCKS, + "auto": AUTO_BLOCKS, +} diff --git a/src/diffusers/modular_pipelines/wan/modular_pipeline.py b/src/diffusers/modular_pipelines/wan/modular_pipeline.py new file mode 100644 index 0000000000..4d86e0d08e --- /dev/null +++ b/src/diffusers/modular_pipelines/wan/modular_pipeline.py @@ -0,0 +1,90 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ...loaders import WanLoraLoaderMixin +from ...pipelines.pipeline_utils import StableDiffusionMixin +from ...utils import logging +from ..modular_pipeline import ModularPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class WanModularPipeline( + ModularPipeline, + StableDiffusionMixin, + WanLoraLoaderMixin, +): + """ + A ModularPipeline for Wan. + + + + This is an experimental feature and is likely to change in the future.
+ + + """ + + @property + def default_height(self): + return self.default_sample_height * self.vae_scale_factor_spatial + + @property + def default_width(self): + return self.default_sample_width * self.vae_scale_factor_spatial + + @property + def default_num_frames(self): + return (self.default_sample_num_frames - 1) * self.vae_scale_factor_temporal + 1 + + @property + def default_sample_height(self): + return 60 + + @property + def default_sample_width(self): + return 104 + + @property + def default_sample_num_frames(self): + return 21 + + @property + def vae_scale_factor_spatial(self): + vae_scale_factor = 8 + if hasattr(self, "vae") and self.vae is not None: + vae_scale_factor = 2 ** len(self.vae.temperal_downsample) + return vae_scale_factor + + @property + def vae_scale_factor_temporal(self): + vae_scale_factor = 4 + if hasattr(self, "vae") and self.vae is not None: + vae_scale_factor = 2 ** sum(self.vae.temperal_downsample) + return vae_scale_factor + + @property + def num_channels_transformer(self): + num_channels_transformer = 16 + if hasattr(self, "transformer") and self.transformer is not None: + num_channels_transformer = self.transformer.config.in_channels + return num_channels_transformer + + @property + def num_channels_latents(self): + num_channels_latents = 16 + if hasattr(self, "vae") and self.vae is not None: + num_channels_latents = self.vae.config.z_dim + return num_channels_latents diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma.py b/src/diffusers/pipelines/chroma/pipeline_chroma.py index c74834ee82..3a34ec2a42 100644 --- a/src/diffusers/pipelines/chroma/pipeline_chroma.py +++ b/src/diffusers/pipelines/chroma/pipeline_chroma.py @@ -663,11 +663,11 @@ class ChromaPipeline( their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. guidance_scale (`float`, *optional*, defaults to 3.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. + Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages + a model to generate images more aligned with `prompt` at the expense of lower image quality. + + Guidance-distilled models approximate true classifier-free guidance for `guidance_scale` > 1. Refer to + the [paper](https://huggingface.co/papers/2210.03142) to learn more. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py b/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py index 9936608aaf..e169db4a4d 100644 --- a/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py +++ b/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py @@ -725,11 +725,11 @@ class ChromaImg2ImgPipeline( their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. guidance_scale (`float`, *optional*, defaults to 5.0): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2.
of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. + Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages + a model to generate images more aligned with `prompt` at the expense of lower image quality. + + Guidance-distilled models approximate true classifier-free guidance for `guidance_scale` > 1. Refer to + the [paper](https://huggingface.co/papers/2210.03142) to learn more. strength (`float, *optional*, defaults to 0.9): Conceptually, indicates how much to transform the reference image. Must be between 0 and 1. image will be used as a starting point, adding more noise to it the larger the strength. The number of denoising diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 6e6e5a4c7f..7211fb5693 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -674,7 +674,8 @@ class FluxPipeline( The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. true_cfg_scale (`float`, *optional*, defaults to 1.0): - When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance. + True classifier-free guidance (guidance scale) is enabled when `true_cfg_scale` > 1 and + `negative_prompt` is provided. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. This is set to 1024 by default for the best results. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): @@ -687,11 +688,11 @@ class FluxPipeline( their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. guidance_scale (`float`, *optional*, defaults to 3.5): - Guidance scale as defined in [Classifier-Free Diffusion - Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. - of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting - `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to - the text `prompt`, usually at the expense of lower image quality. + Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages + a model to generate images more aligned with `prompt` at the expense of lower image quality. + + Guidance-distilled models approximate true classifier-free guidance for `guidance_scale` > 1. Refer to + the [paper](https://huggingface.co/papers/2210.03142) to learn more. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control.py b/src/diffusers/pipelines/flux/pipeline_flux_control.py index ea49821adc..5a057f94cf 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_control.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_control.py @@ -661,11 +661,11 @@ class FluxControlPipeline( their `set_timesteps` method.
If not defined, the default behavior when `num_inference_steps` is passed will be used. guidance_scale (`float`, *optional*, defaults to 3.5): - Guidance scale as defined in [Classifier-Free Diffusion - Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. - of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting - `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to - the text `prompt`, usually at the expense of lower image quality. + Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages + a model to generate images more aligned with `prompt` at the expense of lower image quality. + + Guidance-distilled models approximate true classifier-free guidance for `guidance_scale` > 1. Refer to + the [paper](https://huggingface.co/papers/2210.03142) to learn more. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): diff --git a/src/diffusers/pipelines/flux/pipeline_flux_kontext.py b/src/diffusers/pipelines/flux/pipeline_flux_kontext.py index 94901ee0b6..3c78aeaf36 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_kontext.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_kontext.py @@ -795,11 +795,11 @@ class FluxKontextPipeline( their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. guidance_scale (`float`, *optional*, defaults to 3.5): - Guidance scale as defined in [Classifier-Free Diffusion - Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. - of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting - `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to - the text `prompt`, usually at the expense of lower image quality. + Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages + a model to generate images more aligned with `prompt` at the expense of lower image quality. + + Guidance-distilled models approximate true classifier-free guidance for `guidance_scale` > 1. Refer to + the [paper](https://huggingface.co/papers/2210.03142) to learn more. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): diff --git a/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py index 2b4abe8b24..6dc621901c 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py @@ -989,7 +989,8 @@ class FluxKontextInpaintPipeline( The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. true_cfg_scale (`float`, *optional*, defaults to 1.0): - When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance. + True classifier-free guidance (guidance scale) is enabled when `true_cfg_scale` > 1 and + `negative_prompt` is provided. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image.
This is set to 1024 by default for the best results. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): @@ -1015,11 +1016,11 @@ class FluxKontextInpaintPipeline( their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. guidance_scale (`float`, *optional*, defaults to 3.5): - Guidance scale as defined in [Classifier-Free Diffusion - Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. - of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting - `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to - the text `prompt`, usually at the expense of lower image quality. + Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages + a model to generate images more aligned with `prompt` at the expense of lower image quality. + + Guidance-distilled models approximate true classifier-free guidance for `guidance_scale` > 1. Refer to + the [paper](https://huggingface.co/papers/2210.03142) to learn more. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): diff --git a/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py b/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py index 341cdaf1e6..695f54f3d9 100644 --- a/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py +++ b/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py @@ -763,11 +763,11 @@ class HiDreamImagePipeline(DiffusionPipeline, HiDreamImageLoraLoaderMixin): their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. guidance_scale (`float`, *optional*, defaults to 3.5): - Guidance scale as defined in [Classifier-Free Diffusion - Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. - of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting - `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to - the text `prompt`, usually at the expense of lower image quality. + Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages + a model to generate images more aligned with `prompt` at the expense of lower image quality. + + Guidance-distilled models approximate true classifier-free guidance for `guidance_scale` > 1. Refer to + the [paper](https://huggingface.co/papers/2210.03142) to learn more. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index 2cbb4af2b4..76b288ed0b 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -529,15 +529,14 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin): their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used.
true_cfg_scale (`float`, *optional*, defaults to 1.0): - When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance. + True classifier-free guidance (guidance scale) is enabled when `true_cfg_scale` > 1 and + `negative_prompt` is provided. guidance_scale (`float`, defaults to `6.0`): - Guidance scale as defined in [Classifier-Free Diffusion - Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. - of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting - `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to - the text `prompt`, usually at the expense of lower image quality. Note that the only available - HunyuanVideo model is CFG-distilled, which means that traditional guidance between unconditional and - conditional latent is not applied. + Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages + a model to generate images more aligned with `prompt` at the expense of lower image quality. + + Guidance-distilled models approximate true classifier-free guidance for `guidance_scale` > 1. Refer to + the [paper](https://huggingface.co/papers/2210.03142) to learn more. num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): diff --git a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py index fcf854a54c..e8f9d8368f 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py +++ b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py @@ -643,11 +643,11 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin): in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. guidance_scale (`float`, *optional*, defaults to 4.5): - Guidance scale as defined in [Classifier-Free Diffusion - Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. - of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting - `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to - the text `prompt`, usually at the expense of lower image quality. + Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages + a model to generate images more aligned with `prompt` at the expense of lower image quality. + + Guidance-distilled models approximate true classifier-free guidance for `guidance_scale` > 1. Refer to + the [paper](https://huggingface.co/papers/2210.03142) to learn more. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt.
height (`int`, *optional*, defaults to self.unet.config.sample_size): diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index dde5bbda60..7538635c80 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -32,6 +32,36 @@ class StableDiffusionXLModularPipeline(metaclass=DummyObject): requires_backends(cls, ["torch", "transformers"]) +class WanAutoBlocks(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class WanModularPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class AllegroPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 01dea057de..435bd32c60 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -75,7 +75,6 @@ from diffusers.utils.testing_utils import ( require_torch_2, require_torch_accelerator, require_torch_accelerator_with_training, - require_torch_gpu, require_torch_multi_accelerator, require_torch_version_greater, run_test_in_subprocess, @@ -1829,8 +1828,8 @@ class ModelTesterMixin: assert msg_substring in str(err_ctx.exception) - @parameterized.expand([0, "cuda", torch.device("cuda")]) - @require_torch_gpu + @parameterized.expand([0, torch_device, torch.device(torch_device)]) + @require_torch_accelerator def test_passing_non_dict_device_map_works(self, device_map): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**init_dict).eval() @@ -1839,8 +1838,8 @@ class ModelTesterMixin: loaded_model = self.model_class.from_pretrained(tmpdir, device_map=device_map) _ = loaded_model(**inputs_dict) - @parameterized.expand([("", "cuda"), ("", torch.device("cuda"))]) - @require_torch_gpu + @parameterized.expand([("", torch_device), ("", torch.device(torch_device))]) + @require_torch_accelerator def test_passing_dict_device_map_works(self, name, device): # There are other valid dict-based `device_map` values too. It's best to refer to # the docs for those: https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#the-devicemap. @@ -1945,10 +1944,11 @@ class ModelPushToHubTester(unittest.TestCase): delete_repo(self.repo_id, token=TOKEN) -@require_torch_gpu +@require_torch_accelerator @require_torch_2 @is_torch_compile @slow +@require_torch_version_greater("2.7.1") class TorchCompileTesterMixin: different_shapes_for_compilation = None @@ -2013,7 +2013,7 @@ class TorchCompileTesterMixin: model.eval() # TODO: Can test for other group offloading kwargs later if needed. 
group_offload_kwargs = { - "onload_device": "cuda", + "onload_device": torch_device, "offload_device": "cpu", "offload_type": "block_level", "num_blocks_per_group": 1, @@ -2047,6 +2047,7 @@ class TorchCompileTesterMixin: @require_torch_accelerator @require_peft_backend @require_peft_version_greater("0.14.0") +@require_torch_version_greater("2.7.1") @is_torch_compile class LoraHotSwappingForModelTesterMixin: """Test that hotswapping does not result in recompilation on the model directly. diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py index abf44aa744..123dff16f8 100644 --- a/tests/models/unets/test_models_unet_2d_condition.py +++ b/tests/models/unets/test_models_unet_2d_condition.py @@ -358,7 +358,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test model_class = UNet2DConditionModel main_input_name = "sample" # We override the items here because the unet under consideration is small. - model_split_percents = [0.5, 0.3, 0.4] + model_split_percents = [0.5, 0.34, 0.4] @property def dummy_input(self): diff --git a/tests/pipelines/wan/test_wan.py b/tests/pipelines/wan/test_wan.py index 842b9d19b3..fdb2d29835 100644 --- a/tests/pipelines/wan/test_wan.py +++ b/tests/pipelines/wan/test_wan.py @@ -15,7 +15,6 @@ import gc import unittest -import numpy as np import torch from transformers import AutoTokenizer, T5EncoderModel @@ -29,9 +28,7 @@ from diffusers.utils.testing_utils import ( ) from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import ( - PipelineTesterMixin, -) +from ..test_pipelines_common import PipelineTesterMixin enable_full_determinism() @@ -127,11 +124,15 @@ class WanPipelineFastTests(PipelineTesterMixin, unittest.TestCase): inputs = self.get_dummy_inputs(device) video = pipe(**inputs).frames generated_video = video[0] - self.assertEqual(generated_video.shape, (9, 3, 16, 16)) - expected_video = torch.randn(9, 3, 16, 16) - max_diff = np.abs(generated_video - expected_video).max() - self.assertLessEqual(max_diff, 1e10) + + # fmt: off + expected_slice = torch.tensor([0.4525, 0.452, 0.4485, 0.4534, 0.4524, 0.4529, 0.454, 0.453, 0.5127, 0.5326, 0.5204, 0.5253, 0.5439, 0.5424, 0.5133, 0.5078]) + # fmt: on + + generated_slice = generated_video.flatten() + generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]]) + self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3)) @unittest.skip("Test not supported") def test_attention_slicing_forward_pass(self): diff --git a/tests/pipelines/wan/test_wan_image_to_video.py b/tests/pipelines/wan/test_wan_image_to_video.py index 22dfef2eb0..6edc0cc882 100644 --- a/tests/pipelines/wan/test_wan_image_to_video.py +++ b/tests/pipelines/wan/test_wan_image_to_video.py @@ -14,7 +14,6 @@ import unittest -import numpy as np import torch from PIL import Image from transformers import ( @@ -147,11 +146,15 @@ class WanImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): inputs = self.get_dummy_inputs(device) video = pipe(**inputs).frames generated_video = video[0] - self.assertEqual(generated_video.shape, (9, 3, 16, 16)) - expected_video = torch.randn(9, 3, 16, 16) - max_diff = np.abs(generated_video - expected_video).max() - self.assertLessEqual(max_diff, 1e10) + + # fmt: off + expected_slice = torch.tensor([0.4525, 0.4525, 0.4497, 0.4536, 0.452, 0.4529, 0.454, 0.4535, 0.5072, 0.5527, 0.5165, 0.5244, 0.5481, 0.5282, 0.5208, 
0.5214]) + # fmt: on + + generated_slice = generated_video.flatten() + generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]]) + self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3)) @unittest.skip("Test not supported") def test_attention_slicing_forward_pass(self): @@ -162,7 +165,25 @@ class WanImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): pass -class WanFLFToVideoPipelineFastTests(WanImageToVideoPipelineFastTests): +class WanFLFToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = WanImageToVideoPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "height", "width"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + supports_dduf = False + def get_dummy_components(self): torch.manual_seed(0) vae = AutoencoderKLWan( @@ -247,3 +268,32 @@ class WanFLFToVideoPipelineFastTests(WanImageToVideoPipelineFastTests): "output_type": "pt", } return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + self.assertEqual(generated_video.shape, (9, 3, 16, 16)) + + # fmt: off + expected_slice = torch.tensor([0.4531, 0.4527, 0.4498, 0.4542, 0.4526, 0.4527, 0.4534, 0.4534, 0.5061, 0.5185, 0.5283, 0.5181, 0.5309, 0.5365, 0.5113, 0.5244]) + # fmt: on + + generated_slice = generated_video.flatten() + generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]]) + self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3)) + + @unittest.skip("Test not supported") + def test_attention_slicing_forward_pass(self): + pass + + @unittest.skip("TODO: revisit failing as it requires a very high threshold to pass") + def test_inference_batch_single_identical(self): + pass diff --git a/tests/pipelines/wan/test_wan_video_to_video.py b/tests/pipelines/wan/test_wan_video_to_video.py index 11c748424a..f4bb0960ac 100644 --- a/tests/pipelines/wan/test_wan_video_to_video.py +++ b/tests/pipelines/wan/test_wan_video_to_video.py @@ -14,7 +14,6 @@ import unittest -import numpy as np import torch from PIL import Image from transformers import AutoTokenizer, T5EncoderModel @@ -123,11 +122,15 @@ class WanVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): inputs = self.get_dummy_inputs(device) video = pipe(**inputs).frames generated_video = video[0] - self.assertEqual(generated_video.shape, (17, 3, 16, 16)) - expected_video = torch.randn(17, 3, 16, 16) - max_diff = np.abs(generated_video - expected_video).max() - self.assertLessEqual(max_diff, 1e10) + + # fmt: off + expected_slice = torch.tensor([0.4522, 0.4534, 0.4532, 0.4553, 0.4526, 0.4538, 0.4533, 0.4547, 0.513, 0.5176, 0.5286, 0.4958, 0.4955, 0.5381, 0.5154, 0.5195]) + # fmt:on + + generated_slice = generated_video.flatten() + generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]]) + self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3)) @unittest.skip("Test not supported") def test_attention_slicing_forward_pass(self):