diff --git a/src/diffusers/models/transformers/transformer_chronoedit.py b/src/diffusers/models/transformers/transformer_chronoedit.py index 1156065846..79828b6464 100644 --- a/src/diffusers/models/transformers/transformer_chronoedit.py +++ b/src/diffusers/models/transformers/transformer_chronoedit.py @@ -67,7 +67,7 @@ def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: t return key_img, value_img -# Copied from diffusers.models.transformers.transformer_wan.WanAttnProcessor +# modified from diffusers.models.transformers.transformer_wan.WanAttnProcessor class WanAttnProcessor: _attention_backend = None _parallel_config = None @@ -137,7 +137,8 @@ class WanAttnProcessor: dropout_p=0.0, is_causal=False, backend=self._attention_backend, - parallel_config=self._parallel_config, + # Reference: https://github.com/huggingface/diffusers/pull/12660 + parallel_config=None, ) hidden_states_img = hidden_states_img.flatten(2, 3) hidden_states_img = hidden_states_img.type_as(query) @@ -150,7 +151,8 @@ class WanAttnProcessor: dropout_p=0.0, is_causal=False, backend=self._attention_backend, - parallel_config=self._parallel_config, + # Reference: https://github.com/huggingface/diffusers/pull/12660 + parallel_config=(self._parallel_config if encoder_hidden_states is None else None), ) hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.type_as(query) @@ -568,9 +570,11 @@ class ChronoEditTransformer3DModel( "blocks.0": { "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), }, - "blocks.*": { - "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), - }, + # Reference: https://github.com/huggingface/diffusers/pull/12660 + # We need to disable the splitting of encoder_hidden_states because + # the image_encoder consistently generates 257 tokens for image_embed. This causes + # the shape of encoder_hidden_states—whose token count is always 769 (512 + 257) + # after concatenation—to be indivisible by the number of devices in the CP. "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), }