spatiotemporal guidance: additional wan registrations for attention and attention score skipping
@@ -30,7 +30,7 @@ from ..models.transformers.transformer_hunyuan_video import (
 )
 from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
 from ..models.transformers.transformer_mochi import MochiTransformerBlock
-from ..models.transformers.transformer_wan import WanTransformerBlock
+from ..models.transformers.transformer_wan import WanAttnProcessor2_0, WanPAGAttnProcessor2_0, WanTransformerBlock
 
 
 @dataclass
@@ -101,6 +101,14 @@ def _register_attention_processors_metadata():
         ),
     )
 
+    # Wan
+    AttentionProcessorRegistry.register(
+        model_class=WanAttnProcessor2_0,
+        metadata=AttentionProcessorMetadata(
+            skip_processor_output_fn=_skip_proc_output_fn_Attention_WanAttnProcessor,
+        ),
+    )
+
 
 def _register_guidance_metadata():
     # CogView4
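For context, a hedged sketch of how a registration like this is consumed. The `get` accessor and the hook-side call below are assumptions based on the registry pattern in diffusers' hooks, not lines from this commit, and `attn` / `hidden_states` are illustrative free variables:

    # Illustrative only: `AttentionProcessorRegistry.get` is assumed to mirror `register`.
    metadata = AttentionProcessorRegistry.get(WanAttnProcessor2_0)
    # When spatiotemporal guidance skips this attention layer, the hook bypasses
    # the processor and returns the input via the registered skip function:
    output = metadata.skip_processor_output_fn(attn, hidden_states=hidden_states)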
@@ -111,6 +119,14 @@ def _register_guidance_metadata():
         ),
     )
 
+    # Wan
+    GuidanceMetadataRegistry.register(
+        model_class=WanAttnProcessor2_0,
+        metadata=GuidanceMetadata(
+            perturbed_attention_guidance_processor_cls=WanPAGAttnProcessor2_0,
+        ),
+    )
+
 
 def _register_transformer_blocks_metadata():
     # CogVideoX
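This mapping is what lets perturbed-attention guidance (PAG) swap in the Wan PAG processor for the perturbed forward pass. A minimal sketch, assuming the registry exposes a `get` accessor symmetric to `register`; the swap itself uses the standard `Attention.set_processor` API:

    # Assumed accessor; the PAG class comes straight from the registered metadata.
    metadata = GuidanceMetadataRegistry.get(WanAttnProcessor2_0)
    pag_processor = metadata.perturbed_attention_guidance_processor_cls()
    attn.set_processor(pag_processor)  # run the perturbed branch with the PAG variant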
@@ -217,6 +233,13 @@ def _register_transformer_blocks_metadata():
 
 
 # fmt: off
+def _skip_attention___ret___hidden_states(self, *args, **kwargs):
+    hidden_states = kwargs.get("hidden_states", None)
+    if hidden_states is None and len(args) > 0:
+        hidden_states = args[0]
+    return hidden_states
+
+
 def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
     hidden_states = kwargs.get("hidden_states", None)
     encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
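The triple-underscore naming encodes the return signature: this new helper returns only hidden_states, whether it arrived positionally or as a keyword. A self-contained illustration (the function is restated from the hunk above so the snippet runs on its own):

    import torch

    def _skip_attention___ret___hidden_states(self, *args, **kwargs):
        hidden_states = kwargs.get("hidden_states", None)
        if hidden_states is None and len(args) > 0:
            hidden_states = args[0]
        return hidden_states

    x = torch.randn(1, 4, 8)
    assert _skip_attention___ret___hidden_states(None, x) is x                # positional
    assert _skip_attention___ret___hidden_states(None, hidden_states=x) is x  # keyword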
@@ -228,6 +251,7 @@ def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
 
 
 _skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
+_skip_proc_output_fn_Attention_WanAttnProcessor = _skip_attention___ret___hidden_states
 
 
 def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs):