mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
[Chore] remove all is from auraflow. (#8980)
remove all is from auraflow.
This commit is contained in:
@@ -138,14 +138,14 @@ class AuraFlowSingleTransformerBlock(nn.Module):
|
||||
self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
|
||||
self.ff = AuraFlowFeedForward(dim, dim * 4)
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor, i=9999):
|
||||
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor):
|
||||
residual = hidden_states
|
||||
|
||||
# Norm + Projection.
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
||||
|
||||
# Attention.
|
||||
attn_output = self.attn(hidden_states=norm_hidden_states, i=i)
|
||||
attn_output = self.attn(hidden_states=norm_hidden_states)
|
||||
|
||||
# Process attention outputs for the `hidden_states`.
|
||||
hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output)
|
||||
@@ -201,7 +201,7 @@ class AuraFlowJointTransformerBlock(nn.Module):
|
||||
self.ff_context = AuraFlowFeedForward(dim, dim * 4)
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor, i=0
|
||||
self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
|
||||
):
|
||||
residual = hidden_states
|
||||
residual_context = encoder_hidden_states
|
||||
@@ -214,7 +214,7 @@ class AuraFlowJointTransformerBlock(nn.Module):
|
||||
|
||||
# Attention.
|
||||
attn_output, context_attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, i=i
|
||||
hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
|
||||
)
|
||||
|
||||
# Process attention outputs for the `hidden_states`.
|
||||
@@ -366,7 +366,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, i=index_block
|
||||
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
|
||||
)
|
||||
|
||||
# Single DiT blocks that combine the `hidden_states` (image) and `encoder_hidden_states` (text)
|
||||
|
||||
Reference in New Issue
Block a user