mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
fix scale_shift_factor being on cpu for wan and ltx (#12347)
* wan fix scale_shift_factor being on cpu * apply device cast to ltx transformer * Apply style fixes --------- Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
committed by
sayakpaul
parent
9169e81609
commit
36059182f1
@@ -350,7 +350,9 @@ class LTXVideoTransformerBlock(nn.Module):
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
|
||||
num_ada_params = self.scale_shift_table.shape[0]
|
||||
ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)
|
||||
ada_values = self.scale_shift_table[None, None].to(temb.device) + temb.reshape(
|
||||
batch_size, temb.size(1), num_ada_params, -1
|
||||
)
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
||||
|
||||
|
||||
@@ -665,12 +665,12 @@ class WanTransformer3DModel(
|
||||
# 5. Output norm, projection & unpatchify
|
||||
if temb.ndim == 3:
|
||||
# batch_size, seq_len, inner_dim (wan 2.2 ti2v)
|
||||
shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2)
|
||||
shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2)
|
||||
shift = shift.squeeze(2)
|
||||
scale = scale.squeeze(2)
|
||||
else:
|
||||
# batch_size, inner_dim
|
||||
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
|
||||
shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
|
||||
|
||||
# Move the shift and scale tensors to the same device as hidden_states.
|
||||
# When using multi-GPU inference via accelerate these will be on the
|
||||
|
||||
@@ -103,7 +103,7 @@ class WanVACETransformerBlock(nn.Module):
|
||||
control_hidden_states = control_hidden_states + hidden_states
|
||||
|
||||
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
|
||||
self.scale_shift_table + temb.float()
|
||||
self.scale_shift_table.to(temb.device) + temb.float()
|
||||
).chunk(6, dim=1)
|
||||
|
||||
# 1. Self-attention
|
||||
@@ -359,7 +359,7 @@ class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
|
||||
hidden_states = hidden_states + control_hint * scale
|
||||
|
||||
# 6. Output norm, projection & unpatchify
|
||||
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
|
||||
shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
|
||||
|
||||
# Move the shift and scale tensors to the same device as hidden_states.
|
||||
# When using multi-GPU inference via accelerate these will be on the
|
||||
|
||||
Reference in New Issue
Block a user