Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
Support passing kwargs to the SD3 custom attention processor (#9818)

* Support passing kwargs to the SD3 custom attention processor

Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
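What this change enables: entries placed in `joint_attention_kwargs` are now forwarded from `SD3Transformer2DModel.forward` into every `JointTransformerBlock` and expanded into the block's attention calls, so they reach the installed attention processor's `__call__` (keywords the processor does not declare in its signature are filtered out before the processor is invoked). Below is a minimal sketch, not part of this commit: the `ScaledJointAttnProcessor` class and its `attn_scale` keyword are illustrative assumptions, not diffusers API.

import torch
from diffusers.models.attention_processor import JointAttnProcessor2_0


class ScaledJointAttnProcessor(JointAttnProcessor2_0):
    # Toy custom SD3 processor that consumes an extra kwarg routed via joint_attention_kwargs.
    def __call__(
        self,
        attn,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask=None,
        attn_scale: float = 1.0,  # hypothetical extra kwarg delivered through joint_attention_kwargs
        **kwargs,
    ):
        # Run the stock joint attention, then scale its output(s) by the extra kwarg.
        out = super().__call__(
            attn,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **kwargs,
        )
        if isinstance(out, tuple):  # joint attention returns (hidden_states, encoder_hidden_states)
            return tuple(attn_scale * o for o in out)
        return attn_scale * out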
@@ -188,8 +188,13 @@ class JointTransformerBlock(nn.Module):
         self._chunk_dim = dim
 
     def forward(
-        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
+        joint_attention_kwargs = joint_attention_kwargs or {}
         if self.use_dual_attention:
             norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
                 hidden_states, emb=temb
@@ -206,7 +211,9 @@ class JointTransformerBlock(nn.Module):
 
         # Attention.
         attn_output, context_attn_output = self.attn(
-            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            **joint_attention_kwargs,
         )
 
         # Process attention outputs for the `hidden_states`.
@@ -214,7 +221,7 @@ class JointTransformerBlock(nn.Module):
         hidden_states = hidden_states + attn_output
 
         if self.use_dual_attention:
-            attn_output2 = self.attn2(hidden_states=norm_hidden_states2)
+            attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs)
             attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
             hidden_states = hidden_states + attn_output2
 
@@ -411,11 +411,15 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
                     hidden_states,
                     encoder_hidden_states,
                     temb,
+                    joint_attention_kwargs,
                     **ckpt_kwargs,
                 )
             elif not is_skip:
                 encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    temb=temb,
+                    joint_attention_kwargs=joint_attention_kwargs,
                 )
 
             # controlnet residual
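A usage sketch under the same assumptions: install the `ScaledJointAttnProcessor` sketched above on the transformer's attention layers, then pass the extra keyword through the existing `joint_attention_kwargs` argument, which the hunks above forward into every block (including `attn2` in dual-attention blocks). The checkpoint id and pipeline call below follow the usual SD3 setup and are assumptions for illustration.

import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

# set_attn_processor installs the processor on every Attention module of the transformer.
pipe.transformer.set_attn_processor(ScaledJointAttnProcessor())

# The extra kwarg travels pipeline -> SD3Transformer2DModel -> JointTransformerBlock -> processor.
image = pipe(
    "a photo of an astronaut riding a horse on the moon",
    joint_attention_kwargs={"attn_scale": 0.9},
).images[0]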