mirror of https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
add gradient checkpointing and peft support
This commit is contained in:
@@ -22,6 +22,7 @@ import torch.nn.functional as F
 from torch import BoolTensor, IntTensor, Tensor, nn
 from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature,
                                                flex_attention)
 
 from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
@@ -694,11 +695,12 @@ class Kandinsky5TransformerDecoderBlock(nn.Module):
         return visual_embed
 
 
-class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin):
+class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin):
     """
     A 3D Diffusion Transformer model for video-like data.
     """
 
+    _supports_gradient_checkpointing = True
+
     @register_to_config
     def __init__(
         self,
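The new mixins on the class above are what enable LoRA/PEFT loading. A minimal usage sketch (not part of this commit), assuming the class is importable as diffusers.Kandinsky5Transformer3DModel and that its attention projections follow the usual to_q/to_k/to_v/to_out naming; the checkpoint path and LoRA hyperparameters are illustrative only:

    import torch
    from peft import LoraConfig
    from diffusers import Kandinsky5Transformer3DModel

    # Load the transformer weights (repo path and subfolder are placeholders).
    transformer = Kandinsky5Transformer3DModel.from_pretrained(
        "path/to/kandinsky5-checkpoint",
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
    )

    # add_adapter() comes from PeftAdapterMixin; target_modules is an assumption
    # about the attention projection names, not taken from this diff.
    lora_config = LoraConfig(
        r=16,
        lora_alpha=16,
        init_lora_weights="gaussian",
        target_modules=["to_q", "to_k", "to_v", "to_out.0"],
    )
    transformer.add_adapter(lora_config)

FromOriginalModelMixin additionally exposes from_single_file() for original-format checkpoints, although this commit does not show whether a single-file config is registered for this model.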
@@ -764,6 +766,7 @@ class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin):
 
         # Initialize output layer
         self.out_layer = Kandinsky5OutLayer(model_dim, time_dim, out_visual_dim, patch_size)
+        self.gradient_checkpointing = False
 
     def prepare_text_embeddings(self, text_embed, time, pooled_text_embed, x, text_rope_pos):
         """Prepare text embeddings and related components"""
@@ -787,13 +790,20 @@ class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin):
     def process_text_transformer_blocks(self, text_embed, time_embed, text_rope):
         """Process text through transformer blocks"""
         for text_transformer_block in self.text_transformer_blocks:
-            text_embed = text_transformer_block(text_embed, time_embed, text_rope)
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                text_embed = self._gradient_checkpointing_func(text_transformer_block, text_embed, time_embed, text_rope)
+            else:
+                text_embed = text_transformer_block(text_embed, time_embed, text_rope)
         return text_embed
 
     def process_visual_transformer_blocks(self, visual_embed, text_embed, time_embed, visual_rope, sparse_params):
         """Process visual through transformer blocks"""
         for visual_transformer_block in self.visual_transformer_blocks:
-            visual_embed = visual_transformer_block(visual_embed, text_embed, time_embed,
-                                                    visual_rope, sparse_params)
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                visual_embed = self._gradient_checkpointing_func(visual_transformer_block, visual_embed, text_embed, time_embed,
+                                                                 visual_rope, sparse_params)
+            else:
+                visual_embed = visual_transformer_block(visual_embed, text_embed, time_embed,
+                                                        visual_rope, sparse_params)
         return visual_embed
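A short sketch of how the checkpointing path added above is exercised, continuing from the loading sketch after the class definition. ModelMixin.enable_gradient_checkpointing() only succeeds because _supports_gradient_checkpointing is now True, and the checkpointed branch runs only while autograd is enabled, so inference under torch.no_grad() keeps the plain code path; the forward-call arguments are placeholders because the full forward signature is not shown in this diff:

    import torch

    transformer.enable_gradient_checkpointing()  # allowed since _supports_gradient_checkpointing = True
    transformer.train()

    # Training step: torch.is_grad_enabled() is True, so each block runs through
    # self._gradient_checkpointing_func and its activations are recomputed on backward.
    # loss = some_loss_fn(transformer(...))   # placeholder forward call
    # loss.backward()

    # Inference: inside no_grad() the condition is False, so the non-checkpointed
    # branch is used even though the flag is still set.
    transformer.eval()
    with torch.no_grad():
        pass  # output = transformer(...)     # placeholder forward call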