From 354d35adb02e943d79014e5713290a4551d3dd01 Mon Sep 17 00:00:00 2001
From: DefTruth <31974251+DefTruth@users.noreply.github.com>
Date: Mon, 24 Nov 2025 19:06:53 +0800
Subject: [PATCH] bugfix: fix chrono-edit context parallel (#12660)

* bugfix: fix chrono-edit context parallel

* bugfix: fix chrono-edit context parallel

* Update src/diffusers/models/transformers/transformer_chronoedit.py

Co-authored-by: Dhruv Nair

* Update src/diffusers/models/transformers/transformer_chronoedit.py

Co-authored-by: Dhruv Nair

* Clean up comments in transformer_chronoedit.py

Removed unnecessary comments regarding parallelization in cross-attention.

* fix style

* fix qc

---------

Co-authored-by: Dhruv Nair
---
 .../transformers/transformer_chronoedit.py      | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/src/diffusers/models/transformers/transformer_chronoedit.py b/src/diffusers/models/transformers/transformer_chronoedit.py
index 1156065846..79828b6464 100644
--- a/src/diffusers/models/transformers/transformer_chronoedit.py
+++ b/src/diffusers/models/transformers/transformer_chronoedit.py
@@ -67,7 +67,7 @@ def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: t
     return key_img, value_img
 
 
-# Copied from diffusers.models.transformers.transformer_wan.WanAttnProcessor
+# modified from diffusers.models.transformers.transformer_wan.WanAttnProcessor
 class WanAttnProcessor:
     _attention_backend = None
     _parallel_config = None
@@ -137,7 +137,8 @@ class WanAttnProcessor:
                 dropout_p=0.0,
                 is_causal=False,
                 backend=self._attention_backend,
-                parallel_config=self._parallel_config,
+                # Reference: https://github.com/huggingface/diffusers/pull/12660
+                parallel_config=None,
             )
             hidden_states_img = hidden_states_img.flatten(2, 3)
             hidden_states_img = hidden_states_img.type_as(query)
@@ -150,7 +151,8 @@ class WanAttnProcessor:
             dropout_p=0.0,
             is_causal=False,
             backend=self._attention_backend,
-            parallel_config=self._parallel_config,
+            # Reference: https://github.com/huggingface/diffusers/pull/12660
+            parallel_config=(self._parallel_config if encoder_hidden_states is None else None),
         )
         hidden_states = hidden_states.flatten(2, 3)
         hidden_states = hidden_states.type_as(query)
@@ -568,9 +570,11 @@ class ChronoEditTransformer3DModel(
         "blocks.0": {
             "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
         },
-        "blocks.*": {
-            "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
-        },
+        # Reference: https://github.com/huggingface/diffusers/pull/12660
+        # We need to disable the splitting of encoder_hidden_states because
+        # the image_encoder consistently generates 257 tokens for image_embed. This causes
+        # the token count of encoder_hidden_states, which is always 769 (512 + 257)
+        # after concatenation, to be indivisible by the number of devices in the CP group.
         "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
     }
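
Editor's note, not part of the patch: the rationale in the last hunk can be checked with a minimal sketch. The token counts (512 text tokens plus 257 image_encoder tokens) come from the comment in the diff; the context-parallel world sizes below are illustrative assumptions, not values taken from the code.

    # Sketch: why the concatenated encoder_hidden_states cannot be split evenly
    # across context-parallel ranks. Token counts follow the patch comment;
    # the world sizes (2, 4, 8) are assumed examples.
    text_tokens = 512      # text encoder sequence length
    image_tokens = 257     # image_encoder always emits 257 tokens for image_embed
    seq_len = text_tokens + image_tokens  # 769 after concatenation

    for world_size in (2, 4, 8):
        # 769 is odd, so no even world size divides it and the split would be uneven.
        print(f"world_size={world_size}: divisible={seq_len % world_size == 0}")

This matches the change in the attention processor: parallel_config is passed through only when encoder_hidden_states is None, so self-attention over hidden_states stays sharded under the CP plan while cross-attention against the 769-token context runs unsharded on every rank.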