diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index 342373b4c1..89d51969ae 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -138,14 +138,14 @@ class AuraFlowSingleTransformerBlock(nn.Module): self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False) self.ff = AuraFlowFeedForward(dim, dim * 4) - def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor, i=9999): + def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor): residual = hidden_states # Norm + Projection. norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) # Attention. - attn_output = self.attn(hidden_states=norm_hidden_states, i=i) + attn_output = self.attn(hidden_states=norm_hidden_states) # Process attention outputs for the `hidden_states`. hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output) @@ -201,7 +201,7 @@ class AuraFlowJointTransformerBlock(nn.Module): self.ff_context = AuraFlowFeedForward(dim, dim * 4) def forward( - self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor, i=0 + self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor ): residual = hidden_states residual_context = encoder_hidden_states @@ -214,7 +214,7 @@ class AuraFlowJointTransformerBlock(nn.Module): # Attention. attn_output, context_attn_output = self.attn( - hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, i=i + hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states ) # Process attention outputs for the `hidden_states`. @@ -366,7 +366,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin): else: encoder_hidden_states, hidden_states = block( - hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, i=index_block + hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb ) # Single DiT blocks that combine the `hidden_states` (image) and `encoder_hidden_states` (text)