Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
[Core] add support for gradient checkpointing in transformer_2d (#5943)
@@ -20,7 +20,7 @@ from torch import nn
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..models.embeddings import ImagePositionalEmbeddings
-from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate
+from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
 from .attention import BasicTransformerBlock
 from .embeddings import CaptionProjection, PatchEmbed
 from .lora import LoRACompatibleConv, LoRACompatibleLinear
@@ -70,6 +70,8 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
             Configure if the `TransformerBlocks` attention should contain a bias parameter.
     """

+    _supports_gradient_checkpointing = True
+
     @register_to_config
     def __init__(
         self,
@@ -237,6 +239,10 @@ class Transformer2DModel(ModelMixin, ConfigMixin):

         self.gradient_checkpointing = False

+    def _set_gradient_checkpointing(self, module, value=False):
+        if hasattr(module, "gradient_checkpointing"):
+            module.gradient_checkpointing = value
+
     def forward(
         self,
         hidden_states: torch.Tensor,
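
With `_supports_gradient_checkpointing` and the `_set_gradient_checkpointing` hook in place, the standard `ModelMixin.enable_gradient_checkpointing()` call can toggle the feature on a `Transformer2DModel`. A minimal usage sketch, assuming a small illustrative config rather than a real checkpoint (in practice the model usually comes from `from_pretrained`):

import torch
from diffusers import Transformer2DModel

# Small illustrative config; pretrained checkpoints set these values for you.
model = Transformer2DModel(
    num_attention_heads=2,
    attention_head_dim=16,
    in_channels=32,
    num_layers=2,
)
model.train()

# enable_gradient_checkpointing() is inherited from ModelMixin; it applies
# _set_gradient_checkpointing to every submodule that exposes the attribute,
# which is what the new class flag and hook above make possible.
model.enable_gradient_checkpointing()

hidden_states = torch.randn(1, 32, 8, 8)
out = model(hidden_states).sample
# During backward, each BasicTransformerBlock recomputes its activations
# instead of storing them, trading compute for memory.
out.mean().backward()
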
@@ -360,8 +366,19 @@ class Transformer2DModel(ModelMixin, ConfigMixin):

         for block in self.transformer_blocks:
             if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                 hidden_states = torch.utils.checkpoint.checkpoint(
-                    block,
+                    create_custom_forward(block),
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
@@ -369,7 +386,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
                     timestep,
                     cross_attention_kwargs,
                     class_labels,
-                    use_reentrant=False,
+                    **ckpt_kwargs,
                 )
             else:
                 hidden_states = block(
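
For reference, the forward-pass change mirrors the checkpointing pattern already used in the UNet blocks: the transformer block is wrapped in a closure so that keyword arguments such as `return_dict` can be bound outside the call (torch.utils.checkpoint.checkpoint only forwards positional inputs), and `use_reentrant=False` is passed only on PyTorch >= 1.11, where the argument exists. A self-contained sketch of the same idea, using a hypothetical `checkpointed_call` helper and a toy block in place of `BasicTransformerBlock`:

from typing import Any, Dict

import torch
import torch.utils.checkpoint
from packaging import version


def checkpointed_call(module: torch.nn.Module, *inputs, **bound_kwargs):
    """Run `module` under activation checkpointing, recomputing activations in backward."""

    def custom_forward(*args):
        # Keyword arguments (e.g. return_dict) are bound in the closure because
        # torch.utils.checkpoint.checkpoint only passes positional inputs through.
        return module(*args, **bound_kwargs)

    # `use_reentrant` was added in PyTorch 1.11; older versions omit the kwarg and
    # fall back to the reentrant implementation (the role played by
    # is_torch_version(">=", "1.11.0") in the diff above).
    ckpt_kwargs: Dict[str, Any] = (
        {"use_reentrant": False}
        if version.parse(torch.__version__).release >= (1, 11)
        else {}
    )
    return torch.utils.checkpoint.checkpoint(custom_forward, *inputs, **ckpt_kwargs)


# Toy usage: a linear layer stands in for a transformer block.
block = torch.nn.Linear(16, 16)
hidden_states = torch.randn(2, 16, requires_grad=True)
out = checkpointed_call(block, hidden_states)
out.sum().backward()
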