1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Add gradient checkpointing and PEFT support

This commit is contained in:
leffff
2025-10-14 11:24:24 +00:00
parent e3a3e9d1b6
commit 7af80e9ffc

View File

@@ -22,6 +22,7 @@ import torch.nn.functional as F
from torch import BoolTensor, IntTensor, Tensor, nn
from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature,
flex_attention)
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
@@ -694,11 +695,12 @@ class Kandinsky5TransformerDecoderBlock(nn.Module):
return visual_embed
class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin):
class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin):
"""
A 3D Diffusion Transformer model for video-like data.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
@@ -764,6 +766,7 @@ class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin):
# Initialize output layer
self.out_layer = Kandinsky5OutLayer(model_dim, time_dim, out_visual_dim, patch_size)
self.gradient_checkpointing = False
def prepare_text_embeddings(self, text_embed, time, pooled_text_embed, x, text_rope_pos):
"""Prepare text embeddings and related components"""
@@ -787,13 +790,20 @@ class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin):
def process_text_transformer_blocks(self, text_embed, time_embed, text_rope):
    """Run the text embeddings through every text transformer block in order.

    Each block in ``self.text_transformer_blocks`` is applied exactly once.
    When autograd is enabled and ``self.gradient_checkpointing`` is set, the
    call is routed through ``self._gradient_checkpointing_func``; otherwise
    the block is called directly.

    Args:
        text_embed: Text embeddings fed to the first block; each block's
            output becomes the next block's input.
        time_embed: Time embedding passed unchanged to every block.
        text_rope: Rotary position data passed unchanged to every block
            (presumably RoPE for the text tokens -- confirm against caller).

    Returns:
        The output of the last text transformer block (or ``text_embed``
        unchanged if there are no blocks).
    """
    # Checkpointing is only relevant while gradients are being recorded.
    # The condition is evaluated per iteration, as in the original code.
    for text_transformer_block in self.text_transformer_blocks:
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            text_embed = self._gradient_checkpointing_func(
                text_transformer_block, text_embed, time_embed, text_rope
            )
        else:
            text_embed = text_transformer_block(text_embed, time_embed, text_rope)
    return text_embed
def process_visual_transformer_blocks(self, visual_embed, text_embed, time_embed, visual_rope, sparse_params):
    """Run the visual embeddings through every visual transformer block in order.

    Each block in ``self.visual_transformer_blocks`` is applied exactly once.
    When autograd is enabled and ``self.gradient_checkpointing`` is set, the
    call is routed through ``self._gradient_checkpointing_func``; otherwise
    the block is called directly.

    Args:
        visual_embed: Visual embeddings fed to the first block; each block's
            output becomes the next block's input.
        text_embed: Text embeddings passed unchanged to every block.
        time_embed: Time embedding passed unchanged to every block.
        visual_rope: Rotary position data passed unchanged to every block
            (presumably RoPE for the visual tokens -- confirm against caller).
        sparse_params: Passed unchanged to every block (presumably sparse
            attention settings -- confirm against the block implementation).

    Returns:
        The output of the last visual transformer block (or ``visual_embed``
        unchanged if there are no blocks).
    """
    # Checkpointing is only relevant while gradients are being recorded.
    # The condition is evaluated per iteration, as in the original code.
    for visual_transformer_block in self.visual_transformer_blocks:
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            visual_embed = self._gradient_checkpointing_func(
                visual_transformer_block,
                visual_embed,
                text_embed,
                time_embed,
                visual_rope,
                sparse_params,
            )
        else:
            visual_embed = visual_transformer_block(
                visual_embed, text_embed, time_embed, visual_rope, sparse_params
            )
    return visual_embed