1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

spatiotemporal guidance: additional wan registrations for attention and attention score skipping

This commit is contained in:
Aryan
2025-04-06 04:41:28 +02:00
parent 98fdabde9e
commit 0147a6eb27

View File

@@ -30,7 +30,7 @@ from ..models.transformers.transformer_hunyuan_video import (
)
from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
from ..models.transformers.transformer_mochi import MochiTransformerBlock
from ..models.transformers.transformer_wan import WanTransformerBlock
from ..models.transformers.transformer_wan import WanAttnProcessor2_0, WanPAGAttnProcessor2_0, WanTransformerBlock
@dataclass
@@ -101,6 +101,14 @@ def _register_attention_processors_metadata():
),
)
# Wan
AttentionProcessorRegistry.register(
model_class=WanAttnProcessor2_0,
metadata=AttentionProcessorMetadata(
skip_processor_output_fn=_skip_proc_output_fn_Attention_WanAttnProcessor,
),
)
def _register_guidance_metadata():
# CogView4
@@ -111,6 +119,14 @@ def _register_guidance_metadata():
),
)
# Wan
GuidanceMetadataRegistry.register(
model_class=WanAttnProcessor2_0,
metadata=GuidanceMetadata(
perturbed_attention_guidance_processor_cls=WanPAGAttnProcessor2_0,
),
)
def _register_transformer_blocks_metadata():
# CogVideoX
@@ -217,6 +233,13 @@ def _register_transformer_blocks_metadata():
# fmt: off
def _skip_attention___ret___hidden_states(self, *args, **kwargs):
hidden_states = kwargs.get("hidden_states", None)
if hidden_states is None and len(args) > 0:
hidden_states = args[0]
return hidden_states
def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
hidden_states = kwargs.get("hidden_states", None)
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
@@ -228,6 +251,7 @@ def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, *
_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
_skip_proc_output_fn_Attention_WanAttnProcessor = _skip_attention___ret___hidden_states
def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs):