Mirror of https://github.com/huggingface/diffusers.git

Support passing kwargs to the SD3 custom attention processor (#9818)

* Support passing kwargs to the SD3 custom attention processor


---------

Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Author: Qin Zhou
Date: 2024-12-18 15:58:33 +08:00
Committed by: GitHub
Parent: 88b015dc9f
Commit: 8eb73c872a
2 changed files with 15 additions and 4 deletions


@@ -188,8 +188,13 @@ class JointTransformerBlock(nn.Module):
         self._chunk_dim = dim
 
     def forward(
-        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
+        joint_attention_kwargs = joint_attention_kwargs or {}
         if self.use_dual_attention:
             norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
                 hidden_states, emb=temb
@@ -206,7 +211,9 @@ class JointTransformerBlock(nn.Module):
 
         # Attention.
         attn_output, context_attn_output = self.attn(
-            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            **joint_attention_kwargs,
         )
 
         # Process attention outputs for the `hidden_states`.
@@ -214,7 +221,7 @@ class JointTransformerBlock(nn.Module):
         hidden_states = hidden_states + attn_output
 
         if self.use_dual_attention:
-            attn_output2 = self.attn2(hidden_states=norm_hidden_states2)
+            attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs)
             attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
             hidden_states = hidden_states + attn_output2
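
Below is a minimal sketch (not part of this commit) of what the change enables: a custom joint attention processor that declares an extra keyword argument, which can now reach it through joint_attention_kwargs. The ScaledJointAttnProcessor name and the attention_scale parameter are illustrative assumptions, not diffusers API.

import torch

from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0


class ScaledJointAttnProcessor(JointAttnProcessor2_0):
    """Joint attention processor that rescales the attention output.

    Extra entries passed via joint_attention_kwargs that match the parameter
    names declared here are forwarded to this processor by Attention.forward.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: torch.FloatTensor = None,
        attention_scale: float = 1.0,  # hypothetical kwarg, delivered via joint_attention_kwargs
    ):
        out = super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask)
        # The stock processor returns (hidden_states, encoder_hidden_states) for
        # joint attention and a single tensor for the dual-attention attn2 path.
        if isinstance(out, tuple):
            hidden_states, encoder_hidden_states = out
            return attention_scale * hidden_states, encoder_hidden_states
        return attention_scale * out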


@@ -411,11 +411,15 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
                     hidden_states,
                     encoder_hidden_states,
                     temb,
+                    joint_attention_kwargs,
                     **ckpt_kwargs,
                 )
             elif not is_skip:
                 encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    temb=temb,
+                    joint_attention_kwargs=joint_attention_kwargs,
                 )
 
             # controlnet residual
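
A hedged usage sketch, assuming the ScaledJointAttnProcessor above and the stabilityai/stable-diffusion-3-medium-diffusers checkpoint: install the custom processor on the transformer, then pass the extra argument through joint_attention_kwargs at call time so it is forwarded into every JointTransformerBlock.

import torch

from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

# Route all attention layers in the SD3 transformer through the custom
# processor sketched above.
pipe.transformer.set_attn_processor(ScaledJointAttnProcessor())

image = pipe(
    "a photo of an astronaut riding a horse on the moon",
    joint_attention_kwargs={"attention_scale": 0.9},  # reaches the processor's __call__
).images[0]
image.save("sd3_scaled_attention.png")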