
[Chore] remove all `i`s from auraflow. (#8980)

remove all `i`s from auraflow.
Sayak Paul
2024-07-26 07:31:06 +05:30
committed by GitHub
parent 9b8c8605d1
commit 50e66f2f95

@@ -138,14 +138,14 @@ class AuraFlowSingleTransformerBlock(nn.Module):
         self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
         self.ff = AuraFlowFeedForward(dim, dim * 4)

-    def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor, i=9999):
+    def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor):
         residual = hidden_states

         # Norm + Projection.
         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)

         # Attention.
-        attn_output = self.attn(hidden_states=norm_hidden_states, i=i)
+        attn_output = self.attn(hidden_states=norm_hidden_states)

         # Process attention outputs for the `hidden_states`.
         hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output)
@@ -201,7 +201,7 @@ class AuraFlowJointTransformerBlock(nn.Module):
         self.ff_context = AuraFlowFeedForward(dim, dim * 4)

     def forward(
-        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor, i=0
+        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
     ):
         residual = hidden_states
         residual_context = encoder_hidden_states
@@ -214,7 +214,7 @@ class AuraFlowJointTransformerBlock(nn.Module):
         # Attention.
         attn_output, context_attn_output = self.attn(
-            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, i=i
+            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
         )

         # Process attention outputs for the `hidden_states`.
@@ -366,7 +366,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
             else:
                 encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, i=index_block
+                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
                 )

         # Single DiT blocks that combine the `hidden_states` (image) and `encoder_hidden_states` (text)
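
After this change the blocks are called with only the tensors they operate on. Below is a minimal, hedged smoke-test sketch of the simplified forward() signatures; the import path and the constructor arguments (dim, num_attention_heads, attention_head_dim) are assumptions based on the class definitions in this file, and the sizes are toy values rather than AuraFlow's released configuration.

# Sketch only, not part of the PR. Assumed: import path and constructor
# arguments (dim, num_attention_heads, attention_head_dim); toy tensor sizes.
import torch

from diffusers.models.transformers.auraflow_transformer_2d import (
    AuraFlowJointTransformerBlock,
    AuraFlowSingleTransformerBlock,
)

dim, num_heads, head_dim = 256, 4, 64        # assumed toy sizes
batch, image_tokens, text_tokens = 2, 16, 8  # assumed toy sizes

joint_block = AuraFlowJointTransformerBlock(
    dim=dim, num_attention_heads=num_heads, attention_head_dim=head_dim
)
single_block = AuraFlowSingleTransformerBlock(
    dim=dim, num_attention_heads=num_heads, attention_head_dim=head_dim
)

hidden_states = torch.randn(batch, image_tokens, dim)
encoder_hidden_states = torch.randn(batch, text_tokens, dim)
temb = torch.randn(batch, dim)

# MMDiT-style joint block: no trailing `i` argument anymore.
encoder_hidden_states, hidden_states = joint_block(
    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
)

# Single DiT block: likewise takes only `hidden_states` and `temb` now.
hidden_states = single_block(hidden_states=hidden_states, temb=temb)

print(hidden_states.shape)  # torch.Size([2, 16, 256])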