Fix for multi-GPU WAN inference (#10997)
Ensure that hidden_states and shift/scale are on the same device when running with multiple GPUs.
Co-authored-by: Jimmy <39@🇺🇸.com>
Committed by GitHub
Parent: d87ce2cefc
Commit: e7ffeae0a1
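For context, a hedged sketch of how a run can reach the multi-GPU code path this commit fixes. The checkpoint id below is assumed for illustration; device_map support comes from accelerate's big-model inference, which may place different submodules of the sharded transformer on different GPUs:

    import torch
    from diffusers import WanPipeline

    # Shard the pipeline across available GPUs. With a sharded transformer,
    # hidden_states can end up on the last device, while tensors derived from
    # parameters created on the first device (e.g. scale_shift_table) do not.
    pipe = WanPipeline.from_pretrained(
        "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",  # assumed checkpoint id, for illustration
        torch_dtype=torch.bfloat16,
        device_map="balanced",
    )
    video = pipe(prompt="a cat surfing a wave").frames[0]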
@@ -441,6 +441,14 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
 
         # 5. Output norm, projection & unpatchify
         shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+
+        # Move the shift and scale tensors to the same device as hidden_states.
+        # When using multi-GPU inference via accelerate these will be on the
+        # first device rather than the last device, which hidden_states ends up
+        # on.
+        shift = shift.to(hidden_states.device)
+        scale = scale.to(hidden_states.device)
+
         hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
         hidden_states = self.proj_out(hidden_states)
 
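A minimal sketch of the underlying failure mode, independent of diffusers (tensor shapes and names here are illustrative stand-ins, not the actual layers; requires at least two CUDA devices): tensors derived from a parameter on the first GPU meet activations that earlier sharded layers have moved to a later GPU, and the elementwise modulation raises a device-mismatch RuntimeError unless shift and scale are moved first.

    import torch

    # Illustrative stand-ins for the tensors inside the transformer's forward
    # pass: with accelerate sharding, the table lives on cuda:0 while
    # hidden_states has migrated to cuda:1 by this point.
    scale_shift_table = torch.randn(1, 2, 64, device="cuda:0")
    temb = torch.randn(4, 64, device="cuda:0")
    hidden_states = torch.randn(4, 128, 64, device="cuda:1")

    shift, scale = (scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)

    # Without these two moves, the modulation below raises
    # "RuntimeError: Expected all tensors to be on the same device ...",
    # since shift/scale sit on cuda:0 while hidden_states is on cuda:1.
    shift = shift.to(hidden_states.device)
    scale = scale.to(hidden_states.device)

    out = hidden_states * (1 + scale) + shift  # all operands now on cuda:1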