1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

bugfix: fix chrono-edit context parallel (#12660)

* bugfix: fix chrono-edit context parallel

* bugfix: fix chrono-edit context parallel

* Update src/diffusers/models/transformers/transformer_chronoedit.py

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* Update src/diffusers/models/transformers/transformer_chronoedit.py

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* Clean up comments in transformer_chronoedit.py

Removed unnecessary comments regarding parallelization in cross-attention.

* fix style

* fix qc

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
This commit is contained in:
DefTruth
2025-11-24 19:06:53 +08:00
committed by GitHub
parent 544ba677dd
commit 354d35adb0

View File

@@ -67,7 +67,7 @@ def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: t
return key_img, value_img
# Copied from diffusers.models.transformers.transformer_wan.WanAttnProcessor
# modified from diffusers.models.transformers.transformer_wan.WanAttnProcessor
class WanAttnProcessor:
_attention_backend = None
_parallel_config = None
@@ -137,7 +137,8 @@ class WanAttnProcessor:
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
parallel_config=self._parallel_config,
# Reference: https://github.com/huggingface/diffusers/pull/12660
parallel_config=None,
)
hidden_states_img = hidden_states_img.flatten(2, 3)
hidden_states_img = hidden_states_img.type_as(query)
@@ -150,7 +151,8 @@ class WanAttnProcessor:
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
parallel_config=self._parallel_config,
# Reference: https://github.com/huggingface/diffusers/pull/12660
parallel_config=(self._parallel_config if encoder_hidden_states is None else None),
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.type_as(query)
@@ -568,9 +570,11 @@ class ChronoEditTransformer3DModel(
"blocks.0": {
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
},
"blocks.*": {
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
},
# Reference: https://github.com/huggingface/diffusers/pull/12660
# We need to disable the splitting of encoder_hidden_states because
# the image_encoder consistently generates 257 tokens for image_embed. This causes
# the shape of encoder_hidden_states—whose token count is always 769 (512 + 257)
# after concatenation—to be indivisible by the number of devices in the CP.
"proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
}