From 7ec060d449e73ea6846b8d9105dd21aec0a27461 Mon Sep 17 00:00:00 2001 From: Carolinabanana <140120812+Carolinabanana@users.noreply.github.com> Date: Tue, 18 Jun 2024 22:36:23 +0100 Subject: [PATCH] Fix gradient checkpointing issue for Stable Diffusion 3 (#8542) Co-authored-by: Sayak Paul Co-authored-by: YiYi Xu --- src/diffusers/models/transformers/transformer_sd3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index 4b159511e2..740b19bb53 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -306,7 +306,7 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states = torch.utils.checkpoint.checkpoint( + encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(block), hidden_states, encoder_hidden_states,