
[Core] add support for gradient checkpointing in transformer_2d (#5943)

add support for gradient checkpointing in transformer_2d
Authored by Sayak Paul on 2023-11-27 16:21:12 +05:30; committed by GitHub
parent 7d6f30e89b
commit 3f7c3511dc
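
With _supports_gradient_checkpointing = True, Transformer2DModel can now be put into checkpointed training through the enable_gradient_checkpointing() helper that ModelMixin already provides (the helper refuses models that do not declare support). The following is a minimal sketch of how a training script would use it; the constructor arguments and tensor shapes are illustrative and not taken from this commit:

import torch

from diffusers import Transformer2DModel

# Illustrative config: 8 heads x 8 dims gives inner_dim = 64, and in_channels = 64
# keeps the continuous-input path compatible with the default 32 group-norm groups.
model = Transformer2DModel(
    num_attention_heads=8,
    attention_head_dim=8,
    in_channels=64,
    num_layers=2,
)
model.train()

# Provided by ModelMixin; it raises unless the model declares
# _supports_gradient_checkpointing, which is what this commit adds.
model.enable_gradient_checkpointing()

hidden_states = torch.randn(1, 64, 16, 16)  # (batch, channels, height, width)
out = model(hidden_states).sample

# During backward, each BasicTransformerBlock recomputes its activations
# instead of keeping them in memory, trading compute for peak memory.
out.mean().backward()

The is_torch_version(">=", "1.11.0") guard in the diff below reflects that the use_reentrant argument of torch.utils.checkpoint.checkpoint only exists from PyTorch 1.11 onward; older versions fall back to the default reentrant implementation.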


@@ -20,7 +20,7 @@ from torch import nn
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..models.embeddings import ImagePositionalEmbeddings
-from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate
+from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
 from .attention import BasicTransformerBlock
 from .embeddings import CaptionProjection, PatchEmbed
 from .lora import LoRACompatibleConv, LoRACompatibleLinear
@@ -70,6 +70,8 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
             Configure if the `TransformerBlocks` attention should contain a bias parameter.
     """
 
+    _supports_gradient_checkpointing = True
+
     @register_to_config
     def __init__(
         self,
@@ -237,6 +239,10 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
 
         self.gradient_checkpointing = False
 
+    def _set_gradient_checkpointing(self, module, value=False):
+        if hasattr(module, "gradient_checkpointing"):
+            module.gradient_checkpointing = value
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -360,8 +366,19 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
 
         for block in self.transformer_blocks:
             if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                 hidden_states = torch.utils.checkpoint.checkpoint(
-                    block,
+                    create_custom_forward(block),
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
@@ -369,7 +386,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
                     timestep,
                     cross_attention_kwargs,
                     class_labels,
-                    use_reentrant=False,
+                    **ckpt_kwargs,
                 )
             else:
                 hidden_states = block(