@@ -13,7 +13,7 @@
 # limitations under the License.

 import math
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union, List

 import torch
 import torch.nn.functional as F
@@ -21,19 +21,67 @@ from torch import nn

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers, BaseOutput
 from ..attention import AttentionMixin
 from ..attention_dispatch import dispatch_attention_fn
 from ..attention_processor import Attention
 from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormSingle, RMSNorm

+from dataclasses import dataclass
+

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


+@dataclass
+class SanaBlockKvCache:
+    vk: Optional[torch.Tensor] = None
+    k_sum: Optional[torch.Tensor] = None
+    temporal_cache: Optional[torch.Tensor] = None
+    _enable_save: bool = False
+
+    def disable_save(self):
+        self._enable_save = False
+
+    def enable_save(self):
+        self._enable_save = True
+
+    def maybe_save(
+        self,
+        vk: Optional[torch.Tensor] = None,
+        k_sum: Optional[torch.Tensor] = None,
+        temporal_cache: Optional[torch.Tensor] = None,
+    ):
+        if not self._enable_save:
+            return
+
+        if vk is not None:
+            self.vk = vk.detach().clone()
+        if k_sum is not None:
+            self.k_sum = k_sum.detach().clone()
+        if temporal_cache is not None:
+            self.temporal_cache = temporal_cache.detach().clone()
+
+
+@dataclass
+class SanaVideoCausalTransformer3DModelOutput(BaseOutput):
+    """
+    The output of [`SanaVideoCausalTransformer3DModel`].
+
+    Args:
+        sample (`torch.Tensor` of shape `(batch_size, num_frames, height, width, num_channels)`):
+            The hidden states output conditioned on the `encoder_hidden_states` input.
+        kv_caches (`List[SanaBlockKvCache]`, *optional*):
+            The per-block KV caches of the transformer.
+    """
+
+    sample: "torch.Tensor"  # noqa: F821
+    kv_caches: Optional[List[SanaBlockKvCache]] = None
+
+
 class CachedGLUMBConvTemp(nn.Module):
     def __init__(
         self,
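
The save/enable pattern above replaces the old `save_kv_cache` flag that was threaded through every forward call: callers toggle saving on the cache object itself, and `maybe_save` becomes a no-op whenever saving is disabled. A minimal sketch of the intended usage (illustrative only, not part of the patch; the tensor shapes are arbitrary and the absolute import path assumes this branch of diffusers is installed):

```python
import torch

# Module path inferred from the pipeline's relative import; only exists on this branch.
from diffusers.models.transformers.transformer_sana_video_causal import SanaBlockKvCache

cache = SanaBlockKvCache()
scores = torch.randn(1, 8, 32, 32)  # arbitrary example tensors, not real model shapes
k_sum = torch.randn(1, 8, 1, 32)

# Saving is off by default, so maybe_save() is a no-op here.
cache.maybe_save(vk=scores, k_sum=k_sum)
assert cache.vk is None and cache.k_sum is None

# After enable_save(), maybe_save() stores detached clones of whatever is passed.
cache.enable_save()
cache.maybe_save(vk=scores, k_sum=k_sum)
assert torch.equal(cache.vk, scores) and cache.vk is not scores

# disable_save() turns later maybe_save() calls back into no-ops (e.g. during denoising steps).
cache.disable_save()
cache.maybe_save(vk=torch.zeros_like(scores))
assert torch.equal(cache.vk, scores)
```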
@@ -65,12 +113,11 @@ class CachedGLUMBConvTemp(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        save_kv_cache: bool = False,
-        kv_cache: Optional[list] = None,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, list]]:
+        kv_cache: Optional[SanaBlockKvCache] = None,
+    ) -> Tuple[torch.Tensor, Optional[SanaBlockKvCache]]:
         """
         hidden_states: shape [B, T, H, W, C]
-        kv_cache: list, with kv_cache[0/1/2] for optional cached states (only kv_cache[2] is used here for temporal)
+        kv_cache: SanaBlockKvCache with optional cached states (only temporal_cache is used here)
         """

         if self.residual_connection:
@@ -99,17 +146,13 @@ class CachedGLUMBConvTemp(nn.Module):

         # If using cache, prepend cached frames from last chunk along time axis (dim 2)
         if kv_cache is not None:
-            if len(kv_cache) < 3:
-                kv_cache.extend([None] * (3 - len(kv_cache)))
-            if kv_cache[2] is not None:
-                hidden_states_temporal_in = torch.cat([kv_cache[2], hidden_states_temporal], dim=2)
-                padded_size = kv_cache[2].shape[2]
+            if kv_cache.temporal_cache is not None:
+                hidden_states_temporal_in = torch.cat([kv_cache.temporal_cache, hidden_states_temporal], dim=2)
+                padded_size = kv_cache.temporal_cache.shape[2]
             # Save last padding_size frames for next chunk
-            if save_kv_cache:
-                kv_cache[2] = hidden_states_temporal[:, :, -padding_size:, :].detach().clone()
-        else:
-            if save_kv_cache:
-                kv_cache = [None, None, hidden_states_temporal[:, :, -padding_size:, :].detach().clone()]
+            kv_cache.maybe_save(
+                temporal_cache=hidden_states_temporal[:, :, -padding_size:, :],
+            )

         t_conv_out = self.conv_temp(hidden_states_temporal_in)[:, :, padded_size:]
         hidden_states = hidden_states_temporal + t_conv_out
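
For context on the temporal path above: the block keeps the last `padding_size` frames of the previous chunk and prepends them along the time axis, so the temporal convolution sees a continuous sequence across chunk boundaries before the padded prefix is sliced off again. A self-contained sketch of that mechanism with a plain `nn.Conv3d` stand-in (illustrative only; the layer configuration, tensor layout, and `padding_size` value are made up and not taken from the model config):

```python
import torch
import torch.nn as nn

channels, padding_size = 4, 2
# Stand-in for conv_temp; chosen so the temporal length is preserved.
conv_temp = nn.Conv3d(channels, channels, kernel_size=(3, 1, 1), padding=(1, 0, 0))

prev_chunk = torch.randn(1, channels, 8, 16, 16)  # (B, C, T, H, W)
next_chunk = torch.randn(1, channels, 8, 16, 16)

# What the patch stores in kv_cache.temporal_cache: the last `padding_size` frames.
temporal_cache = prev_chunk[:, :, -padding_size:]

# The next chunk prepends the cached frames, runs the conv, then drops the padded prefix,
# mirroring `self.conv_temp(hidden_states_temporal_in)[:, :, padded_size:]`.
padded_in = torch.cat([temporal_cache, next_chunk], dim=2)
out = conv_temp(padded_in)[:, :, temporal_cache.shape[2]:]
print(out.shape)  # torch.Size([1, 4, 8, 16, 16]) -- one output frame per frame of next_chunk
```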
@@ -121,9 +164,7 @@ class CachedGLUMBConvTemp(nn.Module):
         if self.residual_connection:
             hidden_states = hidden_states + residual

-        if kv_cache is not None or save_kv_cache:
-            return hidden_states, kv_cache
-        return hidden_states
+        return hidden_states, kv_cache


 class SanaCausalLinearAttnProcessor1_0:
@@ -139,9 +180,8 @@ class SanaCausalLinearAttnProcessor1_0:
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         rotary_emb: Optional[torch.Tensor] = None,
-        save_kv_cache: bool = False,
-        kv_cache: Optional[list] = None,
-    ) -> torch.Tensor:
+        kv_cache: Optional[SanaBlockKvCache] = None,
+    ) -> Tuple[torch.Tensor, Optional[SanaBlockKvCache]]:
         original_dtype = hidden_states.dtype

         if encoder_hidden_states is None:
@@ -205,14 +245,8 @@ class SanaCausalLinearAttnProcessor1_0:

         # Handle KV cache for autoregressive generation
         if kv_cache is not None:
-            cached_vk, cached_k_sum = kv_cache[0], kv_cache[1]
-
-            # Save current step's KV to cache if requested
-            if save_kv_cache:
-                kv_cache[0] = scores.detach().clone()
-                kv_cache[1] = k_sum.detach().clone()
-
-            # Accumulate with previous cached values
+            cached_vk, cached_k_sum = kv_cache.vk, kv_cache.k_sum
+            kv_cache.maybe_save(vk=scores, k_sum=k_sum)
             if cached_vk is not None and cached_k_sum is not None:
                 scores = scores + cached_vk
                 k_sum = k_sum + cached_k_sum
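
The reason the cache can be merged with a plain addition above is that linear attention reduces each chunk to running sums: a value-key term (the `scores`/`vk` tensor) and a key normalizer (`k_sum`). The output for a query is a ratio of query-projected sums, so adding per-chunk sums is equivalent to attending over the concatenated sequence. A small numerical check of that identity (illustrative only; it uses a bare ReLU feature map and a single head, and ignores the rotary embeddings and dtype handling of the real processor):

```python
import torch

torch.manual_seed(0)
phi = torch.relu  # simple kernel feature map; the real processor's map may differ

d = 16
q = torch.randn(4, d)                              # queries from the current chunk
k1, v1 = torch.randn(32, d), torch.randn(32, d)    # "previous chunk" keys/values
k2, v2 = torch.randn(32, d), torch.randn(32, d)    # "current chunk" keys/values

def summaries(k, v):
    # Per-chunk running sums: value-key term and key normalizer.
    return phi(k).T @ v, phi(k).sum(dim=0)

vk1, ks1 = summaries(k1, v1)
vk2, ks2 = summaries(k2, v2)

# Cached accumulation, as in `scores = scores + cached_vk; k_sum = k_sum + cached_k_sum`.
vk_total, ks_total = vk1 + vk2, ks1 + ks2
out_cached = (phi(q) @ vk_total) / (phi(q) @ ks_total).unsqueeze(-1)

# Reference: attend over the concatenated sequence directly.
vk_ref, ks_ref = summaries(torch.cat([k1, k2]), torch.cat([v1, v2]))
out_ref = (phi(q) @ vk_ref) / (phi(q) @ ks_ref).unsqueeze(-1)

print(torch.allclose(out_cached, out_ref, atol=1e-5))  # True
```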
@@ -234,11 +268,7 @@ class SanaCausalLinearAttnProcessor1_0:
         hidden_states = attn.to_out[0](hidden_states)
         hidden_states = attn.to_out[1](hidden_states)

-        # Return with cache if applicable
-        if kv_cache is not None:
-            return hidden_states, kv_cache
-
-        return hidden_states
+        return hidden_states, kv_cache


 # Copied from transformers.transformer_sana_video.WanRotaryPosEmbed
@@ -442,14 +472,10 @@ class SanaVideoCausalTransformerBlock(nn.Module):
         mlp_ratio: float = 3.0,
         qk_norm: Optional[str] = "rms_norm_across_heads",
         rope_max_seq_len: int = 1024,
-        self_attn_processor: Optional[nn.Module] = None,
-        ffn_processor: Optional[nn.Module] = None,
     ) -> None:
         super().__init__()

         # 1. Self Attention - must use causal linear attention
-        if self_attn_processor is None:
-            self_attn_processor = SanaCausalLinearAttnProcessor1_0()
         self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=norm_eps)
         self.attn1 = Attention(
             query_dim=dim,
@@ -460,7 +486,7 @@ class SanaVideoCausalTransformerBlock(nn.Module):
             dropout=dropout,
             bias=attention_bias,
             cross_attention_dim=None,
-            processor=self_attn_processor,
+            processor=SanaCausalLinearAttnProcessor1_0(),
         )

         # 2. Cross Attention
@@ -480,9 +506,7 @@ class SanaVideoCausalTransformerBlock(nn.Module):
         )

         # 3. Feed-forward - must use cached conv
-        if ffn_processor is None:
-            ffn_processor = CachedGLUMBConvTemp
-        self.ff = ffn_processor(dim, dim, mlp_ratio, norm_type=None, residual_connection=False)
+        self.ff = CachedGLUMBConvTemp(dim, dim, mlp_ratio, norm_type=None, residual_connection=False)

         self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)

@@ -497,9 +521,8 @@ class SanaVideoCausalTransformerBlock(nn.Module):
         height: int = None,
         width: int = None,
         rotary_emb: Optional[torch.Tensor] = None,
-        save_kv_cache: bool = False,
-        kv_cache: Optional[list] = None,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, list]]:
+        kv_cache: Optional[SanaBlockKvCache] = None,
+    ) -> Tuple[torch.Tensor, Optional[SanaBlockKvCache]]:
         batch_size = hidden_states.shape[0]

         # 1. Modulation
@@ -513,17 +536,11 @@ class SanaVideoCausalTransformerBlock(nn.Module):
         norm_hidden_states = norm_hidden_states.to(hidden_states.dtype)

         # Causal linear attention always supports kv_cache
-        attn_result = self.attn1(
+        attn_output, kv_cache = self.attn1(
             norm_hidden_states,
             rotary_emb=rotary_emb,
-            save_kv_cache=save_kv_cache,
             kv_cache=kv_cache,
         )
-        if isinstance(attn_result, tuple):
-            attn_output, kv_cache = attn_result
-        else:
-            attn_output = attn_result

         hidden_states = hidden_states + gate_msa * attn_output

         # 3. Cross Attention (no cache)
@@ -542,22 +559,15 @@ class SanaVideoCausalTransformerBlock(nn.Module):
         norm_hidden_states = norm_hidden_states.unflatten(1, (frames, height, width))

         # Cached conv always supports kv_cache
-        ff_result = self.ff(
+        ff_output, kv_cache = self.ff(
             norm_hidden_states,
-            save_kv_cache=save_kv_cache,
             kv_cache=kv_cache,
         )
-        if isinstance(ff_result, tuple):
-            ff_output, kv_cache = ff_result
-        else:
-            ff_output = ff_result

         ff_output = ff_output.flatten(1, 3)
         hidden_states = hidden_states + gate_mlp * ff_output

-        if kv_cache is not None or save_kv_cache:
-            return hidden_states, kv_cache
-        return hidden_states
+        return hidden_states, kv_cache


 class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, AttentionMixin):
@@ -667,8 +677,6 @@ class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
                     norm_eps=norm_eps,
                     mlp_ratio=mlp_ratio,
                     qk_norm=qk_norm,
-                    self_attn_processor=SanaCausalLinearAttnProcessor1_0(),
-                    ffn_processor=CachedGLUMBConvTemp,
                 )
                 for _ in range(num_layers)
             ]
@@ -690,11 +698,9 @@ class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
         encoder_attention_mask: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_block_samples: Optional[Tuple[torch.Tensor]] = None,
-        save_kv_cache: bool = False,
-        kv_cache: Optional[list] = None,
+        kv_caches: Optional[List[SanaBlockKvCache]] = None,
         return_dict: bool = True,
-    ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
+    ) -> Union[Tuple[torch.Tensor, ...], SanaVideoCausalTransformer3DModelOutput]:
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)
@@ -752,12 +758,12 @@ class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
         # 2. Transformer blocks with KV cache
         if torch.is_grad_enabled() and self.gradient_checkpointing:
             # Note: gradient checkpointing doesn't support kv_cache (requires tuple return)
-            if kv_cache is not None:
+            if kv_caches is not None:
                 logger.warning("KV cache is not supported with gradient checkpointing. Disabling KV cache.")
-                kv_cache = None
+                kv_caches = None

             for index_block, block in enumerate(self.transformer_blocks):
-                hidden_states = self._gradient_checkpointing_func(
+                hidden_states, _ = self._gradient_checkpointing_func(
                     block,
                     hidden_states,
                     attention_mask,
@@ -768,16 +774,14 @@ class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
                     post_patch_height,
                     post_patch_width,
                     rotary_emb,
+                    kv_cache=None,
                 )
                 if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
                     hidden_states = hidden_states + controlnet_block_samples[index_block - 1]

         else:
             for index_block, block in enumerate(self.transformer_blocks):
                 # Get kv_cache for this block if available
-                block_kv_cache = kv_cache[index_block] if kv_cache is not None else None
+                block_kv_cache = kv_caches[index_block] if kv_caches is not None else None

-                block_result = block(
+                hidden_states, block_kv_cache = block(
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
@@ -787,20 +791,12 @@ class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
                     post_patch_height,
                     post_patch_width,
                     rotary_emb,
-                    save_kv_cache=save_kv_cache,
                     kv_cache=block_kv_cache,
                 )

-                # Handle return value (could be tensor or tuple)
-                if isinstance(block_result, tuple):
-                    hidden_states, updated_kv_cache = block_result
-                    if kv_cache is not None:
-                        kv_cache[index_block] = updated_kv_cache
-                else:
-                    hidden_states = block_result
-
                 if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
                     hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
+                if kv_caches is not None:
+                    kv_caches[index_block] = block_kv_cache

         # 3. Normalization
         hidden_states = self.norm_out(hidden_states, embedded_timestep, self.scale_shift_table)
@@ -819,10 +815,6 @@ class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
             unscale_lora_layers(self, lora_scale)

         if not return_dict:
-            if kv_cache is not None or save_kv_cache:
-                return (output, kv_cache)
-            return (output,)
+            return (output, kv_caches)

-        if kv_cache is not None or save_kv_cache:
-            return Transformer2DModelOutput(sample=output), kv_cache
-        return Transformer2DModelOutput(sample=output)
+        return SanaVideoCausalTransformer3DModelOutput(sample=output, kv_caches=kv_caches)

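
With the changes above, the model's forward always returns the cache list alongside the sample: a `(sample, kv_caches)` tuple when `return_dict=False`, or a `SanaVideoCausalTransformer3DModelOutput` otherwise. A sketch of how a caller is expected to consume this (illustrative only, not part of the patch; `transformer`, `latents`, `prompt_embeds`, `prompt_attention_mask`, and `timestep` are assumed to already exist with shapes the model accepts):

```python
# Assumes `transformer` is a SanaVideoCausalTransformer3DModel from this branch and the
# other tensors are prepared as in the pipeline; none of those names are defined here.
from diffusers.models.transformers.transformer_sana_video_causal import SanaBlockKvCache

num_blocks = transformer.config.num_layers
kv_caches = [SanaBlockKvCache() for _ in range(num_blocks)]  # one cache per transformer block

# Tuple form: forward returns (sample, kv_caches) when return_dict=False.
sample, kv_caches = transformer(
    latents,
    encoder_hidden_states=prompt_embeds,
    encoder_attention_mask=prompt_attention_mask,
    timestep=timestep,
    kv_caches=kv_caches,
    return_dict=False,
)

# Dataclass form: the same information via the new output class.
output = transformer(
    latents,
    encoder_hidden_states=prompt_embeds,
    encoder_attention_mask=prompt_attention_mask,
    timestep=timestep,
    kv_caches=kv_caches,
)
sample, kv_caches = output.sample, output.kv_caches
```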
@@ -26,7 +26,7 @@ from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFa
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...loaders import SanaLoraLoaderMixin
 from ...models import AutoencoderDC, AutoencoderKLWan
-from ...models.transformers.transformer_sana_video_causal import SanaVideoCausalTransformer3DModel
+from ...models.transformers.transformer_sana_video_causal import SanaVideoCausalTransformer3DModel, SanaBlockKvCache
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import (
     BACKENDS_MAPPING,
@@ -97,6 +97,77 @@ EXAMPLE_DOC_STRING = """
 """


+class LongSanaKvCache:
+    def __init__(self, num_chunks: int, num_blocks: int):
+        """
+        Initialize the KV caches for all chunks.
+
+        Args:
+            num_chunks: Number of chunks
+            num_blocks: Number of transformer blocks
+        """
+        kv_caches = []
+        for _ in range(num_chunks):
+            kv_caches.append([SanaBlockKvCache(vk=None, k_sum=None, temporal_cache=None) for _ in range(num_blocks)])
+        self.num_chunks = num_chunks
+        self.num_blocks = num_blocks
+        self.kv_caches = kv_caches
+
+    def get_chunk_cache(self, chunk_idx: int) -> List[SanaBlockKvCache]:
+        return self.kv_caches[chunk_idx]
+
+    def get_block_cache(self, chunk_idx: int, block_idx: int) -> SanaBlockKvCache:
+        return self.kv_caches[chunk_idx][block_idx]
+
+    def update_chunk_cache(self, chunk_idx: int, chunk_kv_cache: List[SanaBlockKvCache]):
+        self.kv_caches[chunk_idx] = chunk_kv_cache
+
+    def get_accumulated_chunk_cache(self, chunk_idx: int, num_cached_blocks: int = -1) -> List[SanaBlockKvCache]:
+        """
+        Accumulate the KV caches from previous chunks.
+
+        Args:
+            chunk_idx: Current chunk index
+            num_cached_blocks: Number of previous chunks to use for accumulation. -1 means use all previous chunks.
+
+        Returns:
+            Accumulated KV cache for the current chunk, as a list of SanaBlockKvCache (one per block).
+        """
+        if chunk_idx == 0:
+            return self.kv_caches[0]
+
+        accumulated_kv_caches = []  # one SanaBlockKvCache per transformer block
+        for block_id in range(self.num_blocks):
+            start_chunk_idx = chunk_idx - num_cached_blocks if num_cached_blocks > 0 else 0
+            # Start from an empty accumulator: vk, k_sum, and temporal_cache are all None.
+            acc_block_cache = SanaBlockKvCache(vk=None, k_sum=None, temporal_cache=None)
+
+            # Accumulate the spatial KV cache from previous chunks.
+            for prev_chunk_idx in range(start_chunk_idx, chunk_idx):
+                prev_kv_cache = self.kv_caches[prev_chunk_idx][block_id]
+
+                if prev_kv_cache.vk is None or prev_kv_cache.k_sum is None:
+                    continue
+
+                if acc_block_cache.vk is not None and acc_block_cache.k_sum is not None:
+                    acc_block_cache.vk += prev_kv_cache.vk
+                    acc_block_cache.k_sum += prev_kv_cache.k_sum
+                else:
+                    # Initialize vk and k_sum from the first available previous chunk.
+                    acc_block_cache.vk = prev_kv_cache.vk.clone()
+                    acc_block_cache.k_sum = prev_kv_cache.k_sum.clone()
+            # Copy the temporal cache from the immediately preceding chunk.
+            acc_block_cache.temporal_cache = self.kv_caches[chunk_idx - 1][block_id].temporal_cache
+
+            accumulated_kv_caches.append(acc_block_cache)
+
+        return accumulated_kv_caches
+
+
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
     scheduler,
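
`LongSanaKvCache` wraps the per-chunk, per-block cache bookkeeping that the removed helper methods below used to do with nested lists: `get_accumulated_chunk_cache` sums `vk`/`k_sum` over the cached window and carries the temporal cache forward from the previous chunk. A small sketch of the lifecycle (illustrative only, not part of the patch; the chunk/block counts are arbitrary, and `LongSanaKvCache` refers to the class added just above):

```python
# LongSanaKvCache is the class defined in the hunk above; run this in the same module.
cache = LongSanaKvCache(num_chunks=3, num_blocks=2)

# Before denoising chunk 1, build its view of earlier chunks: vk / k_sum are summed over
# the cached window and the temporal cache is taken from chunk 0.
chunk_kv_cache = cache.get_accumulated_chunk_cache(chunk_idx=1, num_cached_blocks=-1)
for block_cache in chunk_kv_cache:
    block_cache.disable_save()  # denoising steps only read the cache

# After the chunk is denoised, a second forward pass stores its contribution.
for block_cache in chunk_kv_cache:
    block_cache.enable_save()
# ... run the transformer with kv_caches=chunk_kv_cache ...
cache.update_chunk_cache(1, chunk_kv_cache)
```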
@@ -721,74 +792,6 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         chunk_indices.append(total_frames)
         return chunk_indices

-    def _initialize_kv_cache(self, num_chunks: int, num_blocks: int) -> List:
-        """
-        Initialize KV cache for all chunks.
-
-        Args:
-            num_chunks: Number of chunks
-            num_blocks: Number of transformer blocks
-
-        Returns:
-            List of KV cache for each chunk
-        """
-        kv_cache = []
-        for _ in range(num_chunks):
-            kv_cache.append([[None, None, None] for _ in range(num_blocks)])
-        return kv_cache
-
-    def _accumulate_kv_cache(self, kv_cache: List, chunk_idx: int, num_blocks: int):
-        """
-        Accumulate KV cache from previous chunks.
-
-        Args:
-            kv_cache: List of KV cache for all chunks
-            chunk_idx: Current chunk index
-            num_blocks: Number of transformer blocks
-
-        Returns:
-            Accumulated KV cache for current chunk
-        """
-        if chunk_idx == 0:
-            return kv_cache[0]
-
-        cur_kv_cache = kv_cache[chunk_idx]
-        for block_id in range(num_blocks):
-            # Copy temporal cache from previous chunk
-            cur_kv_cache[block_id][2] = kv_cache[chunk_idx - 1][block_id][2]
-
-            # Accumulate spatial KV cache from previous chunks
-            cum_vk, cum_k_sum = None, None
-            start_chunk_idx = chunk_idx - self.num_cached_blocks if self.num_cached_blocks > 0 else 0
-
-            for i in range(start_chunk_idx, chunk_idx):
-                prev = kv_cache[i][block_id]
-                if prev[0] is not None and prev[1] is not None:
-                    if cum_vk is None:
-                        cum_vk = prev[0].clone()
-                        cum_k_sum = prev[1].clone()
-                    else:
-                        cum_vk += prev[0]
-                        cum_k_sum += prev[1]
-
-            if chunk_idx > 0:
-                assert cum_vk is not None and cum_k_sum is not None, "KV cache accumulation failed"
-
-            cur_kv_cache[block_id][0] = cum_vk
-            cur_kv_cache[block_id][1] = cum_k_sum
-
-        return cur_kv_cache
-
-    def _get_num_transformer_blocks(self) -> int:
-        """Get the number of transformer blocks in the model."""
-        if hasattr(self.transformer, "blocks"):
-            return len(self.transformer.blocks)
-        elif hasattr(self.transformer, "transformer_blocks"):
-            return len(self.transformer.transformer_blocks)
-        elif hasattr(self.transformer, "layers"):
-            return len(self.transformer.layers)
-        else:
-            raise ValueError("Cannot determine number of transformer blocks")
-
     @property
     def guidance_scale(self):
@@ -1062,10 +1065,10 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         num_chunks = len(chunk_indices) - 1

         # Get number of transformer blocks
-        num_blocks = self._get_num_transformer_blocks()
+        num_blocks = self.transformer.config.num_layers

         # Initialize KV cache for all chunks
-        kv_cache = self._initialize_kv_cache(num_chunks, num_blocks)
+        kv_cache = LongSanaKvCache(num_chunks=num_chunks, num_blocks=num_blocks)

         # Output tensor to store denoised results
         output = torch.zeros_like(latents)
@@ -1081,7 +1084,9 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
             local_latent = latents[:, :, start_f:end_f].clone()

             # Accumulate KV cache from previous chunks
-            chunk_kv_cache = self._accumulate_kv_cache(kv_cache, chunk_idx, num_blocks)
+            chunk_kv_cache = kv_cache.get_accumulated_chunk_cache(chunk_idx, num_cached_blocks=self.num_cached_blocks)
+            for block_cache in chunk_kv_cache:
+                block_cache.disable_save()

             # Multi-step denoising for this chunk
             with self.progress_bar(total=len(denoising_step_list)) as progress_bar:
@@ -1098,36 +1103,18 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
                     t = torch.tensor([current_timestep], device=device, dtype=torch.long)
                     timestep = t.expand(latent_model_input.shape[0])

                     # Forward pass through transformer with KV cache
-                    transformer_kwargs = {
-                        "encoder_hidden_states": prompt_embeds.to(dtype=transformer_dtype),
-                        "encoder_attention_mask": prompt_attention_mask,
-                        "timestep": timestep,
-                        "return_dict": False,
-                        "save_kv_cache": False,  # Don't save during denoising steps
-                        "kv_cache": chunk_kv_cache,  # Pass accumulated KV cache
-                    }
-
-                    if self.attention_kwargs is not None:
-                        transformer_kwargs["attention_kwargs"] = self.attention_kwargs
-
                     # Predict flow
-                    model_output = self.transformer(
+                    flow_pred, _ = self.transformer(
                         latent_model_input.to(dtype=transformer_dtype),
-                        **transformer_kwargs,
+                        encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype),
+                        encoder_attention_mask=prompt_attention_mask,
+                        timestep=timestep,
+                        return_dict=False,
+                        kv_caches=chunk_kv_cache,
+                        attention_kwargs=self.attention_kwargs,
                     )

-                    # Handle different output formats
-                    if isinstance(model_output, tuple):
-                        if len(model_output) == 2:
-                            flow_pred, updated_kv_cache = model_output
-                            # Update chunk_kv_cache with new values
-                            if updated_kv_cache is not None:
-                                chunk_kv_cache = updated_kv_cache
-                        else:
-                            flow_pred = model_output[0]
-                    else:
-                        flow_pred = model_output
-
                     flow_pred = flow_pred.float()

@@ -1191,29 +1178,21 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
             latent_for_cache = output[:, :, start_f:end_f]
             timestep_zero = torch.zeros(latent_for_cache.shape[0], device=device, dtype=torch.long)

-            cache_kwargs = {
-                "encoder_hidden_states": prompt_embeds.to(dtype=transformer_dtype),
-                "encoder_attention_mask": prompt_attention_mask,
-                "timestep": timestep_zero,
-                "return_dict": False,
-                "save_kv_cache": True,  # Enable saving during cache update
-                "kv_cache": chunk_kv_cache,
-            }
-
-            if self.attention_kwargs is not None:
-                cache_kwargs["attention_kwargs"] = self.attention_kwargs
+            for block_cache in chunk_kv_cache:
+                block_cache.enable_save()

             # Forward pass to update KV cache
-            cache_output = self.transformer(
+            _, chunk_kv_cache = self.transformer(
                 latent_for_cache.to(dtype=transformer_dtype),
-                **cache_kwargs,
+                encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype),
+                encoder_attention_mask=prompt_attention_mask,
+                timestep=timestep_zero,
+                return_dict=False,
+                kv_caches=chunk_kv_cache,
+                attention_kwargs=self.attention_kwargs,
            )

-            # Extract updated KV cache if returned
-            if isinstance(cache_output, tuple) and len(cache_output) == 2:
-                _, updated_kv_cache = cache_output
-                if updated_kv_cache is not None:
-                    kv_cache[chunk_idx] = updated_kv_cache
+            kv_cache.update_chunk_cache(chunk_idx, chunk_kv_cache)

             if XLA_AVAILABLE:
                 xm.mark_step()

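Taken together, the pipeline changes above give each chunk a two-phase schedule: the denoising steps read an accumulated, save-disabled cache, and a single extra forward pass at timestep zero re-runs the finished chunk with saving enabled so its KV contribution is stored for later chunks. A condensed sketch of that control flow (illustrative only, not runnable as-is; `denoise_chunk` and `run_transformer` are hypothetical stand-ins for the pipeline code):

```python
# Condensed control flow of the chunked loop; helper functions are hypothetical.
kv_cache = LongSanaKvCache(num_chunks=num_chunks, num_blocks=transformer.config.num_layers)

for chunk_idx in range(num_chunks):
    # Phase 1: denoise this chunk against the read-only accumulated cache.
    chunk_kv_cache = kv_cache.get_accumulated_chunk_cache(chunk_idx, num_cached_blocks=-1)  # pipeline uses self.num_cached_blocks
    for block_cache in chunk_kv_cache:
        block_cache.disable_save()
    denoised = denoise_chunk(chunk_idx, kv_caches=chunk_kv_cache)  # hypothetical helper

    # Phase 2: one forward pass at timestep 0 on the denoised latents, with saving enabled,
    # so this chunk's vk / k_sum / temporal_cache are written back.
    for block_cache in chunk_kv_cache:
        block_cache.enable_save()
    _, chunk_kv_cache = run_transformer(denoised, timestep=0, kv_caches=chunk_kv_cache)  # hypothetical helper
    kv_cache.update_chunk_cache(chunk_idx, chunk_kv_cache)
```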