From 7af80e9ffcf4daef408d0f1c99b115c70ae73756 Mon Sep 17 00:00:00 2001 From: leffff Date: Tue, 14 Oct 2025 11:24:24 +0000 Subject: [PATCH] add gradient checkpointing and peft support --- .../transformers/transformer_kandinsky.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 01c9b258b7..6dec8d93ac 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -22,6 +22,7 @@ import torch.nn.functional as F from torch import BoolTensor, IntTensor, Tensor, nn from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature, flex_attention) +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin @@ -694,11 +695,12 @@ class Kandinsky5TransformerDecoderBlock(nn.Module): return visual_embed -class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin): +class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin): """ A 3D Diffusion Transformer model for video-like data. """ - + _supports_gradient_checkpointing = True + @register_to_config def __init__( self, @@ -764,6 +766,7 @@ class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin): # Initialize output layer self.out_layer = Kandinsky5OutLayer(model_dim, time_dim, out_visual_dim, patch_size) + self.gradient_checkpointing = False def prepare_text_embeddings(self, text_embed, time, pooled_text_embed, x, text_rope_pos): """Prepare text embeddings and related components""" @@ -787,13 +790,20 @@ class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin): def process_text_transformer_blocks(self, text_embed, time_embed, text_rope): """Process text through transformer blocks""" for text_transformer_block in self.text_transformer_blocks: - text_embed = text_transformer_block(text_embed, time_embed, text_rope) + if torch.is_grad_enabled() and self.gradient_checkpointing: + text_embed = self._gradient_checkpointing_func(text_transformer_block, text_embed, time_embed, text_rope) + else: + text_embed = text_transformer_block(text_embed, time_embed, text_rope) return text_embed def process_visual_transformer_blocks(self, visual_embed, text_embed, time_embed, visual_rope, sparse_params): """Process visual through transformer blocks""" for visual_transformer_block in self.visual_transformer_blocks: - visual_embed = visual_transformer_block(visual_embed, text_embed, time_embed, + if torch.is_grad_enabled() and self.gradient_checkpointing: + visual_embed = self._gradient_checkpointing_func(visual_transformer_block, visual_embed, text_embed, time_embed, + visual_rope, sparse_params) + else: + visual_embed = visual_transformer_block(visual_embed, text_embed, time_embed, visual_rope, sparse_params) return visual_embed