diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py
index 24abf54d6d..3aecc43f0f 100644
--- a/src/diffusers/models/transformer_2d.py
+++ b/src/diffusers/models/transformer_2d.py
@@ -20,7 +20,7 @@ from torch import nn
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..models.embeddings import ImagePositionalEmbeddings
-from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate
+from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
 from .attention import BasicTransformerBlock
 from .embeddings import CaptionProjection, PatchEmbed
 from .lora import LoRACompatibleConv, LoRACompatibleLinear
@@ -70,6 +70,8 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
             Configure if the `TransformerBlocks` attention should contain a bias parameter.
     """
 
+    _supports_gradient_checkpointing = True
+
     @register_to_config
     def __init__(
         self,
@@ -237,6 +239,10 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
 
         self.gradient_checkpointing = False
 
+    def _set_gradient_checkpointing(self, module, value=False):
+        if hasattr(module, "gradient_checkpointing"):
+            module.gradient_checkpointing = value
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -360,8 +366,19 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
 
         for block in self.transformer_blocks:
             if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                 hidden_states = torch.utils.checkpoint.checkpoint(
-                    block,
+                    create_custom_forward(block),
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
@@ -369,7 +386,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
                     timestep,
                     cross_attention_kwargs,
                     class_labels,
-                    use_reentrant=False,
+                    **ckpt_kwargs,
                )
             else:
                 hidden_states = block(
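
Not part of the patch: a minimal usage sketch of the path this diff enables, going through `ModelMixin.enable_gradient_checkpointing()`, which calls the new `_set_gradient_checkpointing` on every submodule. The hyperparameters and tensor shapes below are illustrative assumptions, not values taken from the diff.

# Illustrative sketch (not from the patch): enable the new checkpointing path
# on a small Transformer2DModel and run a training step through it.
import torch

from diffusers.models.transformer_2d import Transformer2DModel

# Small, assumed hyperparameters; in_channels=32 keeps the default GroupNorm(32) valid.
model = Transformer2DModel(
    num_attention_heads=2,
    attention_head_dim=16,
    in_channels=32,
    num_layers=2,
)
model.enable_gradient_checkpointing()  # allowed now that _supports_gradient_checkpointing is True
model.train()                          # the checkpointed branch also requires self.training

sample = torch.randn(2, 32, 16, 16)
out = model(sample).sample             # each BasicTransformerBlock runs under torch.utils.checkpoint
out.mean().backward()                  # block activations are recomputed during backward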