From a9cb08af398c9fe06d2d62bd12942458d5dba151 Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Fri, 7 Nov 2025 22:30:13 +0800 Subject: [PATCH] fix the crash in Wan-AI/Wan2.2-TI2V-5B-Diffusers if CP is enabled (#12562) * fix the crash in Wan-AI/Wan2.2-TI2V-5B-Diffusers if CP is enabled Signed-off-by: Wang, Yi * address review comment Signed-off-by: Wang, Yi A * refine Signed-off-by: Wang, Yi A --------- Signed-off-by: Wang, Yi Signed-off-by: Wang, Yi A --- src/diffusers/hooks/context_parallel.py | 8 +++++--- src/diffusers/models/transformers/transformer_wan.py | 3 +++ 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/diffusers/hooks/context_parallel.py b/src/diffusers/hooks/context_parallel.py index 915fe453b9..6491d17b4f 100644 --- a/src/diffusers/hooks/context_parallel.py +++ b/src/diffusers/hooks/context_parallel.py @@ -203,10 +203,12 @@ class ContextParallelSplitHook(ModelHook): def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor: if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims: - raise ValueError( - f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions." + logger.warning_once( + f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions, split will not be applied." ) - return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh) + return x + else: + return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh) class ContextParallelGatherHook(ModelHook): diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index dd75fb124f..6f3993eb3f 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -555,6 +555,9 @@ class WanTransformer3DModel( "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), }, "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), + "": { + "timestep": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False), + }, } @register_to_config