From e7ffeae0a191f710881d1fbde00cd6ff025e81f2 Mon Sep 17 00:00:00 2001
From: "39th president of the United States, probably" <110263573+AmericanPresidentJimmyCarter@users.noreply.github.com>
Date: Tue, 11 Mar 2025 13:42:12 -0400
Subject: [PATCH] Fix for multi-GPU WAN inference (#10997)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Ensure that hidden_states and shift/scale are on the same device when
running with multiple GPUs.

Co-authored-by: Jimmy <39@🇺🇸.com>
---
 src/diffusers/models/transformers/transformer_wan.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py
index 66cdda388c..4eb4add376 100644
--- a/src/diffusers/models/transformers/transformer_wan.py
+++ b/src/diffusers/models/transformers/transformer_wan.py
@@ -441,6 +441,14 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
 
         # 5. Output norm, projection & unpatchify
         shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+
+        # Move the shift and scale tensors to the same device as hidden_states.
+        # When using multi-GPU inference via accelerate these will be on the
+        # first device rather than the last device, which hidden_states ends up
+        # on.
+        shift = shift.to(hidden_states.device)
+        scale = scale.to(hidden_states.device)
+
         hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
         hidden_states = self.proj_out(hidden_states)
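
For context, below is a minimal sketch of the multi-GPU setup that can trigger the
device mismatch this patch fixes. It is an assumption about usage, not part of the
patch: the checkpoint id, prompt, and call arguments are placeholders, and the exact
WanPipeline call signature may differ from what is shown here.

    import torch
    from diffusers import WanPipeline

    # Sharding the pipeline across GPUs with accelerate's "balanced" device_map
    # can place scale_shift_table on the first device while hidden_states ends
    # up on the last one; the .to(hidden_states.device) calls in this patch
    # bring shift/scale back onto the same device.
    pipe = WanPipeline.from_pretrained(
        "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",  # assumed checkpoint id
        torch_dtype=torch.bfloat16,
        device_map="balanced",
    )

    # Assumed call arguments; without the fix this forward pass can raise a
    # cross-device RuntimeError in the transformer's output norm.
    output = pipe(
        prompt="a cat surfing a small wave at sunset",
        num_frames=33,
        num_inference_steps=30,
    )
    frames = output.frames[0]

Casting shift and scale to hidden_states.device inside the forward pass keeps the fix
local to the model and avoids changing how accelerate assigns scale_shift_table to a
device.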