diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py index 45ae6a8781..d11f6c2a5e 100644 --- a/src/diffusers/models/transformers/transformer_chroma.py +++ b/src/diffusers/models/transformers/transformer_chroma.py @@ -671,7 +671,7 @@ class ChromaTransformer2DModel( ) if torch.is_grad_enabled() and self.gradient_checkpointing: encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( - block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask=attention_mask + block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask ) else: