
fix scale_shift_factor being on cpu for wan and ltx (#12347)

* wan fix scale_shift_factor being on cpu

* apply device cast to ltx transformer

* Apply style fixes

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Authored by: Vladimir Mandic
Committed by: sayakpaul
Date: 2025-10-04 23:53:38 -04:00
Parent: 9169e81609
Commit: 36059182f1
3 changed files with 7 additions and 5 deletions
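Background: with accelerate offloading (e.g. enable_model_cpu_offload or a device_map), an nn.Parameter such as scale_shift_table can still be on the CPU when the conditioning tensor temb has already been moved to the GPU, so the elementwise add raises a device-mismatch RuntimeError. A minimal sketch of the failure and of the cast this commit applies (illustrative shapes; assumes a CUDA device is available):

    import torch

    # Parameter left behind on the CPU, as can happen under accelerate offloading.
    scale_shift_table = torch.nn.Parameter(torch.randn(1, 2, 64))
    temb = torch.randn(4, 64, device="cuda")  # activations already on the GPU

    try:
        shift, scale = (scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
    except RuntimeError as err:
        print(err)  # Expected all tensors to be on the same device ...

    # The fix: cast the table to temb's device before the add.
    shift, scale = (scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)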

src/diffusers/models/transformers/transformer_ltx.py

@@ -350,7 +350,9 @@ class LTXVideoTransformerBlock(nn.Module):
         norm_hidden_states = self.norm1(hidden_states)
 
         num_ada_params = self.scale_shift_table.shape[0]
-        ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)
+        ada_values = self.scale_shift_table[None, None].to(temb.device) + temb.reshape(
+            batch_size, temb.size(1), num_ada_params, -1
+        )
         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
         norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
 
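For reference, the AdaLN broadcast this hunk touches, as a standalone sketch with illustrative sizes (not the real model dimensions); the .to(temb.device) cast is a no-op once the table and temb share a device:

    import torch

    batch_size, seq_len, dim = 2, 8, 32
    scale_shift_table = torch.randn(6, dim)            # shift/scale/gate for msa and mlp
    temb = torch.randn(batch_size, seq_len, 6 * dim)   # per-token conditioning

    num_ada_params = scale_shift_table.shape[0]
    ada_values = scale_shift_table[None, None].to(temb.device) + temb.reshape(
        batch_size, temb.size(1), num_ada_params, -1
    )
    shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
    print(shift_msa.shape)  # torch.Size([2, 8, 32])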

src/diffusers/models/transformers/transformer_wan.py

@@ -665,12 +665,12 @@ class WanTransformer3DModel(
         # 5. Output norm, projection & unpatchify
         if temb.ndim == 3:
             # batch_size, seq_len, inner_dim (wan 2.2 ti2v)
-            shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2)
+            shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2)
             shift = shift.squeeze(2)
             scale = scale.squeeze(2)
         else:
             # batch_size, inner_dim
-            shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+            shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
 
         # Move the shift and scale tensors to the same device as hidden_states.
         # When using multi-GPU inference via accelerate these will be on the
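The two branches cover the two temb layouts Wan produces: per-token conditioning for Wan 2.2 ti2v (temb.ndim == 3) and per-sample conditioning otherwise. A sketch of both paths with illustrative sizes:

    import torch

    inner_dim = 32
    scale_shift_table = torch.randn(1, 2, inner_dim)

    # wan 2.2 ti2v: temb is (batch_size, seq_len, inner_dim)
    temb = torch.randn(2, 8, inner_dim)
    shift, scale = (scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2)
    print(shift.squeeze(2).shape)  # torch.Size([2, 8, 32])

    # otherwise: temb is (batch_size, inner_dim)
    temb = torch.randn(2, inner_dim)
    shift, scale = (scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
    print(shift.shape)  # torch.Size([2, 1, 32])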

src/diffusers/models/transformers/transformer_wan_vace.py

@@ -103,7 +103,7 @@ class WanVACETransformerBlock(nn.Module):
             control_hidden_states = control_hidden_states + hidden_states
 
         shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
-            self.scale_shift_table + temb.float()
+            self.scale_shift_table.to(temb.device) + temb.float()
         ).chunk(6, dim=1)
 
         # 1. Self-attention
@@ -359,7 +359,7 @@ class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
                 hidden_states = hidden_states + control_hint * scale
 
         # 6. Output norm, projection & unpatchify
-        shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+        shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
 
         # Move the shift and scale tensors to the same device as hidden_states.
         # When using multi-GPU inference via accelerate these will be on the
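The VACE block splits the modulated table six ways: shift/scale/gate for the main stream plus the control stream. A sketch with illustrative sizes; here too the cast is free when both tensors are already colocated:

    import torch

    dim = 32
    scale_shift_table = torch.randn(1, 6, dim)
    temb = torch.randn(2, 6, dim)  # per-sample conditioning, one slot per ada param

    shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
        scale_shift_table.to(temb.device) + temb.float()
    ).chunk(6, dim=1)
    print(shift_msa.shape)  # torch.Size([2, 1, 32])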