diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py index 9dabc7b286..d885b7326e 100644 --- a/src/diffusers/hooks/_helpers.py +++ b/src/diffusers/hooks/_helpers.py @@ -30,7 +30,7 @@ from ..models.transformers.transformer_hunyuan_video import ( ) from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock from ..models.transformers.transformer_mochi import MochiTransformerBlock -from ..models.transformers.transformer_wan import WanTransformerBlock +from ..models.transformers.transformer_wan import WanAttnProcessor2_0, WanPAGAttnProcessor2_0, WanTransformerBlock @dataclass @@ -101,6 +101,14 @@ def _register_attention_processors_metadata(): ), ) + # Wan + AttentionProcessorRegistry.register( + model_class=WanAttnProcessor2_0, + metadata=AttentionProcessorMetadata( + skip_processor_output_fn=_skip_proc_output_fn_Attention_WanAttnProcessor, + ), + ) + def _register_guidance_metadata(): # CogView4 @@ -111,6 +119,14 @@ def _register_guidance_metadata(): ), ) + # Wan + GuidanceMetadataRegistry.register( + model_class=WanAttnProcessor2_0, + metadata=GuidanceMetadata( + perturbed_attention_guidance_processor_cls=WanPAGAttnProcessor2_0, + ), + ) + def _register_transformer_blocks_metadata(): # CogVideoX @@ -217,6 +233,13 @@ def _register_transformer_blocks_metadata(): # fmt: off +def _skip_attention___ret___hidden_states(self, *args, **kwargs): + hidden_states = kwargs.get("hidden_states", None) + if hidden_states is None and len(args) > 0: + hidden_states = args[0] + return hidden_states + + def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs): hidden_states = kwargs.get("hidden_states", None) encoder_hidden_states = kwargs.get("encoder_hidden_states", None) @@ -228,6 +251,7 @@ def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, * _skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states +_skip_proc_output_fn_Attention_WanAttnProcessor = _skip_attention___ret___hidden_states def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs):