mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
This commit is contained in:
yiyixuxu
2025-12-07 08:42:06 +01:00
parent 437cc9d734
commit 8b177ff9c6
2 changed files with 179 additions and 208 deletions

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import math
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
@@ -21,19 +21,67 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils import BaseOutput, USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention import AttentionMixin
from ..attention_dispatch import dispatch_attention_fn
from ..attention_processor import Attention
from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle, RMSNorm
from dataclasses import dataclass
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class SanaBlockKvCache:
vk: Optional[torch.Tensor] = None
k_sum: Optional[torch.Tensor] = None
temporal_cache: Optional[torch.Tensor] = None
_enable_save: bool = False
def disable_save(self):
self._enable_save = False
def enable_save(self):
self._enable_save = True
def maybe_save(
self,
vk: Optional[torch.Tensor] = None,
k_sum: Optional[torch.Tensor] = None,
temporal_cache: Optional[torch.Tensor] = None,
):
if not self._enable_save:
return
if vk is not None:
self.vk = vk.detach().clone()
if k_sum is not None:
self.k_sum = k_sum.detach().clone()
if temporal_cache is not None:
self.temporal_cache = temporal_cache.detach().clone()
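The `_enable_save` gate is what lets the same cache object be threaded through read-only denoising passes and a single write pass. A minimal standalone sketch of that gating, using a trimmed copy of the dataclass above (only torch is needed):

import torch
from dataclasses import dataclass
from typing import Optional

@dataclass
class BlockCache:  # trimmed copy of SanaBlockKvCache above
    vk: Optional[torch.Tensor] = None
    _enable_save: bool = False

    def enable_save(self):
        self._enable_save = True

    def disable_save(self):
        self._enable_save = False

    def maybe_save(self, vk: Optional[torch.Tensor] = None):
        if not self._enable_save:
            return  # reads still work; writes are silently dropped
        if vk is not None:
            self.vk = vk.detach().clone()

cache = BlockCache()
cache.maybe_save(vk=torch.ones(2))  # gate closed: nothing is stored
assert cache.vk is None
cache.enable_save()
cache.maybe_save(vk=torch.ones(2))  # gate open: a detached clone is stored
assert cache.vk is not None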
@dataclass
class SanaVideoCausalTransformer3DModelOutput(BaseOutput):
"""
The output of [`SanaVideoCausalTransformer3DModel`].
Args:
sample (`torch.Tensor` of shape `(batch_size, num_channels, num_frames, height, width)`):
The hidden states output conditioned on the `encoder_hidden_states` input.
kv_caches (`List[SanaBlockKvCache]`, *optional*):
The KV caches for the transformer blocks, one per block.
"""
sample: "torch.Tensor" # noqa: F821
kv_caches: Optional[List[SanaBlockKvCache]] = None
class CachedGLUMBConvTemp(nn.Module):
def __init__(
self,
@@ -65,12 +113,11 @@ class CachedGLUMBConvTemp(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
save_kv_cache: bool = False,
kv_cache: Optional[list] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, list]]:
kv_cache: Optional[SanaBlockKvCache] = None,
) -> Tuple[torch.Tensor, Optional[SanaBlockKvCache]]:
"""
hidden_states: shape [B, T, H, W, C]
kv_cache: list, with kv_cache[0/1/2] for optional cached states (only kv_cache[2] is used here for temporal)
kv_cache: SanaBlockKvCache with optional cached states (only temporal_cache is used in this module)
"""
if self.residual_connection:
@@ -99,17 +146,13 @@ class CachedGLUMBConvTemp(nn.Module):
# If using cache, prepend cached frames from last chunk along time axis (dim 2)
if kv_cache is not None:
if len(kv_cache) < 3:
kv_cache.extend([None] * (3 - len(kv_cache)))
if kv_cache[2] is not None:
hidden_states_temporal_in = torch.cat([kv_cache[2], hidden_states_temporal], dim=2)
padded_size = kv_cache[2].shape[2]
if kv_cache.temporal_cache is not None:
hidden_states_temporal_in = torch.cat([kv_cache.temporal_cache, hidden_states_temporal], dim=2)
padded_size = kv_cache.temporal_cache.shape[2]
# Save last padding_size frames for next chunk
if save_kv_cache:
kv_cache[2] = hidden_states_temporal[:, :, -padding_size:, :].detach().clone()
else:
if save_kv_cache:
kv_cache = [None, None, hidden_states_temporal[:, :, -padding_size:, :].detach().clone()]
kv_cache.maybe_save(
temporal_cache=hidden_states_temporal[:, :, -padding_size:, :],
)
t_conv_out = self.conv_temp(hidden_states_temporal_in)[:, :, padded_size:]
hidden_states = hidden_states_temporal + t_conv_out
@@ -121,9 +164,7 @@ class CachedGLUMBConvTemp(nn.Module):
if self.residual_connection:
hidden_states = hidden_states + residual
if kv_cache is not None or save_kv_cache:
return hidden_states, kv_cache
return hidden_states
return hidden_states, kv_cache
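The temporal cache replays the last padding_size frames of the previous chunk so the causal temporal conv sees them as left context. A simplified, self-contained sketch of the prepend-then-process idea (the real module runs a padded conv and then drops the first padded_size outputs; this sketch uses an unpadded conv so the alignment is explicit):

import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv3d(4, 4, kernel_size=(3, 1, 1))  # temporal kernel 3, no padding
padding_size = 2                               # frames carried between chunks
x = torch.randn(1, 4, 8, 2, 2)                 # [B, C, T, H, W], two chunks of 4 frames

full = conv(x)                                 # one pass over all 8 frames

chunk1, chunk2 = x[:, :, :4], x[:, :, 4:]
temporal_cache = chunk1[:, :, -padding_size:]  # tail of chunk 1, saved for chunk 2
out1 = conv(chunk1)
out2 = conv(torch.cat([temporal_cache, chunk2], dim=2))  # prepend cached frames

# the chunked outputs line up exactly with the full pass
assert torch.allclose(full, torch.cat([out1, out2], dim=2), atol=1e-5)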
class SanaCausalLinearAttnProcessor1_0:
@@ -139,9 +180,8 @@ class SanaCausalLinearAttnProcessor1_0:
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
rotary_emb: Optional[torch.Tensor] = None,
save_kv_cache: bool = False,
kv_cache: Optional[list] = None,
) -> torch.Tensor:
kv_cache: Optional[SanaBlockKvCache] = None,
) -> Tuple[torch.Tensor, Optional[SanaBlockKvCache]]:
original_dtype = hidden_states.dtype
if encoder_hidden_states is None:
@@ -205,14 +245,8 @@ class SanaCausalLinearAttnProcessor1_0:
# Handle KV cache for autoregressive generation
if kv_cache is not None:
cached_vk, cached_k_sum = kv_cache[0], kv_cache[1]
# Save current step's KV to cache if requested
if save_kv_cache:
kv_cache[0] = scores.detach().clone()
kv_cache[1] = k_sum.detach().clone()
# Accumulate with previous cached values
cached_vk, cached_k_sum = kv_cache.vk, kv_cache.k_sum
kv_cache.maybe_save(vk=scores, k_sum=k_sum)
if cached_vk is not None and cached_k_sum is not None:
scores = scores + cached_vk
k_sum = k_sum + cached_k_sum
@@ -234,11 +268,7 @@ class SanaCausalLinearAttnProcessor1_0:
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
# Return with cache if applicable
if kv_cache is not None:
return hidden_states, kv_cache
return hidden_states
return hidden_states, kv_cache
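The processor can cache across chunks because, in linear attention, the key-value summary (the `scores` tensor saved as `vk`) and the key sum are plain sums over tokens, so per-chunk partial sums add up to the full-sequence state. A self-contained sketch of that accumulation (single head, ReLU feature map assumed; the diff does not show the kernel):

import torch

def linear_attn(q, k, v, cached_vk=None, cached_k_sum=None):
    q, k = torch.relu(q), torch.relu(k)      # assumed feature map
    vk = k.transpose(-2, -1) @ v             # (d, d) summary of this chunk
    k_sum = k.sum(dim=-2, keepdim=True)      # (1, d)
    if cached_vk is not None:                # accumulate previous chunks
        vk = vk + cached_vk
        k_sum = k_sum + cached_k_sum
    out = (q @ vk) / (q @ k_sum.transpose(-2, -1) + 1e-6)
    return out, vk, k_sum

torch.manual_seed(0)
t, d = 6, 8
q, k, v = (torch.randn(2 * t, d) for _ in range(3))

full, _, _ = linear_attn(q, k, v)                   # both chunks at once
_, vk1, ks1 = linear_attn(q[:t], k[:t], v[:t])      # chunk 1; its state gets cached
out2, _, _ = linear_attn(q[t:], k[t:], v[t:], vk1, ks1)

# chunk 2 with the accumulated cache matches the full pass on chunk 2's rows
assert torch.allclose(full[t:], out2, atol=1e-5)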
# Copied from transformers.transformer_sana_video.WanRotaryPosEmbed
@@ -442,14 +472,10 @@ class SanaVideoCausalTransformerBlock(nn.Module):
mlp_ratio: float = 3.0,
qk_norm: Optional[str] = "rms_norm_across_heads",
rope_max_seq_len: int = 1024,
self_attn_processor: Optional[nn.Module] = None,
ffn_processor: Optional[nn.Module] = None,
) -> None:
super().__init__()
# 1. Self Attention - must use causal linear attention
if self_attn_processor is None:
self_attn_processor = SanaCausalLinearAttnProcessor1_0()
self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=norm_eps)
self.attn1 = Attention(
query_dim=dim,
@@ -460,7 +486,7 @@ class SanaVideoCausalTransformerBlock(nn.Module):
dropout=dropout,
bias=attention_bias,
cross_attention_dim=None,
processor=self_attn_processor,
processor=SanaCausalLinearAttnProcessor1_0(),
)
# 2. Cross Attention
@@ -480,9 +506,7 @@ class SanaVideoCausalTransformerBlock(nn.Module):
)
# 3. Feed-forward - must use cached conv
if ffn_processor is None:
ffn_processor = CachedGLUMBConvTemp
self.ff = ffn_processor(dim, dim, mlp_ratio, norm_type=None, residual_connection=False)
self.ff = CachedGLUMBConvTemp(dim, dim, mlp_ratio, norm_type=None, residual_connection=False)
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
@@ -497,9 +521,8 @@ class SanaVideoCausalTransformerBlock(nn.Module):
height: int = None,
width: int = None,
rotary_emb: Optional[torch.Tensor] = None,
save_kv_cache: bool = False,
kv_cache: Optional[list] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, list]]:
kv_cache: Optional[SanaBlockKvCache] = None,
) -> Tuple[torch.Tensor, Optional[SanaBlockKvCache]]:
batch_size = hidden_states.shape[0]
# 1. Modulation
@@ -513,17 +536,11 @@ class SanaVideoCausalTransformerBlock(nn.Module):
norm_hidden_states = norm_hidden_states.to(hidden_states.dtype)
# Causal linear attention always supports kv_cache
attn_result = self.attn1(
attn_output, kv_cache = self.attn1(
norm_hidden_states,
rotary_emb=rotary_emb,
save_kv_cache=save_kv_cache,
kv_cache=kv_cache,
)
if isinstance(attn_result, tuple):
attn_output, kv_cache = attn_result
else:
attn_output = attn_result
hidden_states = hidden_states + gate_msa * attn_output
# 3. Cross Attention (no cache)
@@ -542,22 +559,15 @@ class SanaVideoCausalTransformerBlock(nn.Module):
norm_hidden_states = norm_hidden_states.unflatten(1, (frames, height, width))
# Cached conv always supports kv_cache
ff_result = self.ff(
ff_output, kv_cache = self.ff(
norm_hidden_states,
save_kv_cache=save_kv_cache,
kv_cache=kv_cache,
)
if isinstance(ff_result, tuple):
ff_output, kv_cache = ff_result
else:
ff_output = ff_result
ff_output = ff_output.flatten(1, 3)
hidden_states = hidden_states + gate_mlp * ff_output
if kv_cache is not None or save_kv_cache:
return hidden_states, kv_cache
return hidden_states
return hidden_states, kv_cache
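Note that a single SanaBlockKvCache instance serves both sub-layers: attn1 fills vk and k_sum, while self.ff fills temporal_cache, so the two writers never clobber each other. A small sketch of that threading (toy tensors; field names from the dataclass above):

import torch
from dataclasses import dataclass
from typing import Optional

@dataclass
class BlockCache:  # stand-in for SanaBlockKvCache
    vk: Optional[torch.Tensor] = None
    k_sum: Optional[torch.Tensor] = None
    temporal_cache: Optional[torch.Tensor] = None

cache = BlockCache()
# the attention processor writes the linear-attention state ...
cache.vk, cache.k_sum = torch.zeros(8, 8), torch.zeros(1, 8)
# ... and the temporal FFN conv later writes its frames, independently
cache.temporal_cache = torch.zeros(1, 8, 2, 4, 4)
assert cache.vk is not None and cache.temporal_cache is not None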
class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, AttentionMixin):
@@ -667,8 +677,6 @@ class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
norm_eps=norm_eps,
mlp_ratio=mlp_ratio,
qk_norm=qk_norm,
self_attn_processor=SanaCausalLinearAttnProcessor1_0(),
ffn_processor=CachedGLUMBConvTemp,
)
for _ in range(num_layers)
]
@@ -690,11 +698,9 @@ class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
encoder_attention_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_block_samples: Optional[Tuple[torch.Tensor]] = None,
save_kv_cache: bool = False,
kv_cache: Optional[list] = None,
kv_caches: Optional[List[SanaBlockKvCache]] = None,
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
) -> Union[Tuple[torch.Tensor, ...], SanaVideoCausalTransformer3DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
@@ -752,12 +758,12 @@ class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
# 2. Transformer blocks with KV cache
if torch.is_grad_enabled() and self.gradient_checkpointing:
# Note: gradient checkpointing doesn't support kv_cache (requires tuple return)
if kv_cache is not None:
if kv_caches is not None:
logger.warning("KV cache is not supported with gradient checkpointing. Disabling KV cache.")
kv_cache = None
kv_caches = None
for index_block, block in enumerate(self.transformer_blocks):
hidden_states = self._gradient_checkpointing_func(
hidden_states, _ = self._gradient_checkpointing_func(
block,
hidden_states,
attention_mask,
@@ -768,16 +774,14 @@ class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
post_patch_height,
post_patch_width,
rotary_emb,
kv_cache=None,
)
if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
else:
for index_block, block in enumerate(self.transformer_blocks):
# Get kv_cache for this block if available
block_kv_cache = kv_cache[index_block] if kv_cache is not None else None
block_kv_cache = kv_caches[index_block] if kv_caches is not None else None
block_result = block(
hidden_states, block_kv_cache = block(
hidden_states,
attention_mask,
encoder_hidden_states,
@@ -787,20 +791,12 @@ class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
post_patch_height,
post_patch_width,
rotary_emb,
save_kv_cache=save_kv_cache,
kv_cache=block_kv_cache,
)
# Handle return value (could be tensor or tuple)
if isinstance(block_result, tuple):
hidden_states, updated_kv_cache = block_result
if kv_cache is not None:
kv_cache[index_block] = updated_kv_cache
else:
hidden_states = block_result
if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
if kv_caches is not None:
kv_caches[index_block] = block_kv_cache
# 3. Normalization
hidden_states = self.norm_out(hidden_states, embedded_timestep, self.scale_shift_table)
@@ -819,10 +815,6 @@ class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixi
unscale_lora_layers(self, lora_scale)
if not return_dict:
if kv_cache is not None or save_kv_cache:
return (output, kv_cache)
return (output,)
return (output, kv_caches)
if kv_cache is not None or save_kv_cache:
return Transformer2DModelOutput(sample=output), kv_cache
return Transformer2DModelOutput(sample=output)
return SanaVideoCausalTransformer3DModelOutput(sample=output, kv_caches=kv_caches)
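With these changes the forward has a uniform return contract: with return_dict=False it always returns a 2-tuple (sample, kv_caches), even when kv_caches is None, and otherwise it returns the new output class. A minimal sketch of unpacking both forms (stand-in stubs, not a real model call):

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Output:  # stand-in for SanaVideoCausalTransformer3DModelOutput
    sample: int
    kv_caches: Optional[List] = None

def forward(x, kv_caches=None, return_dict=True):
    if not return_dict:
        return (x, kv_caches)  # always a 2-tuple now, even with no caches
    return Output(sample=x, kv_caches=kv_caches)

sample, caches = forward(1, return_dict=False)  # unconditional unpacking is safe
out = forward(1)
sample, caches = out.sample, out.kv_caches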

View File

@@ -26,7 +26,7 @@ from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFa
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...loaders import SanaLoraLoaderMixin
from ...models import AutoencoderDC, AutoencoderKLWan
from ...models.transformers.transformer_sana_video_causal import SanaVideoCausalTransformer3DModel
from ...models.transformers.transformer_sana_video_causal import SanaVideoCausalTransformer3DModel, SanaBlockKvCache
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
BACKENDS_MAPPING,
@@ -97,6 +97,77 @@ EXAMPLE_DOC_STRING = """
"""
class LongSanaKvCache:
def __init__(self, num_chunks: int, num_blocks: int):
"""
Initialize KV cache for all chunks.
Args:
num_chunks: Number of chunks
num_blocks: Number of transformer blocks
Returns:
List of KV cache for each chunk
"""
kv_caches = []
for _ in range(num_chunks):
kv_caches.append([SanaBlockKvCache(vk=None, k_sum=None, temporal_cache=None) for _ in range(num_blocks)])
self.num_chunks = num_chunks
self.num_blocks = num_blocks
self.kv_caches = kv_caches
def get_chunk_cache(self, chunk_idx: int) -> List[SanaBlockKvCache]:
return self.kv_caches[chunk_idx]
def get_block_cache(self, chunk_idx: int, block_idx: int) -> SanaBlockKvCache:
return self.kv_caches[chunk_idx][block_idx]
def update_chunk_cache(self, chunk_idx: int, chunk_kv_cache: List[SanaBlockKvCache]):
self.kv_caches[chunk_idx] = chunk_kv_cache
def get_accumulated_chunk_cache(self, chunk_idx: int, num_cached_blocks: int = -1) -> List[SanaBlockKvCache]:
"""
Accumulate KV cache from previous chunks.
Args:
chunk_idx: Current chunk index
num_cached_blocks: Number of previous chunks to use for accumulation. -1 means use all previous chunks.
Returns:
Accumulated KV cache for current chunk, a list of SanaBlockKvCache.
"""
if chunk_idx == 0:
return self.kv_caches[0]
accumulated_kv_caches = [] # a list of SanaBlockKvCache
for block_id in range(self.num_blocks):
start_chunk_idx = chunk_idx - num_cached_blocks if num_cached_blocks > 0 else 0
# Initialize the accumulated block cache; vk, k_sum, and temporal_cache all start as None.
acc_block_cache = SanaBlockKvCache(vk=None, k_sum=None, temporal_cache=None)
# Accumulate spatial KV cache from previous chunks
for prev_chunk_idx in range(start_chunk_idx, chunk_idx):
prev_kv_cache = self.kv_caches[prev_chunk_idx][block_id]
if prev_kv_cache.vk is None or prev_kv_cache.k_sum is None:
continue
if acc_block_cache.vk is not None and acc_block_cache.k_sum is not None:
acc_block_cache.vk += prev_kv_cache.vk
acc_block_cache.k_sum += prev_kv_cache.k_sum
else:
# initialize vk and k_sum from the first available previous chunk's cache.
acc_block_cache.vk = prev_kv_cache.vk.clone()
acc_block_cache.k_sum = prev_kv_cache.k_sum.clone()
# copy the temporal cache from the previous chunk.
acc_block_cache.temporal_cache = self.kv_caches[chunk_idx - 1][block_id].temporal_cache
accumulated_kv_caches.append(acc_block_cache)
return accumulated_kv_caches
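A self-contained sketch of the accumulation rule above: spatial state (vk, k_sum) is summed over the selected previous chunks, while temporal state is copied only from the immediately preceding chunk (trimmed stand-in class, toy tensors):

import torch
from dataclasses import dataclass
from typing import Optional

@dataclass
class BlockCache:  # trimmed stand-in for SanaBlockKvCache
    vk: Optional[torch.Tensor] = None
    k_sum: Optional[torch.Tensor] = None
    temporal_cache: Optional[torch.Tensor] = None

# three chunks, one block; chunk i stores vk = k_sum = i + 1
store = [
    [BlockCache(vk=torch.full((2, 2), float(i + 1)),
                k_sum=torch.full((1, 2), float(i + 1)),
                temporal_cache=torch.tensor([float(i)]))]
    for i in range(3)
]

# accumulate for chunk_idx = 2 over all previous chunks (0 and 1)
acc = BlockCache()
for prev_idx in range(0, 2):
    prev = store[prev_idx][0]
    if acc.vk is None:
        acc.vk, acc.k_sum = prev.vk.clone(), prev.k_sum.clone()
    else:
        acc.vk += prev.vk
        acc.k_sum += prev.k_sum
acc.temporal_cache = store[1][0].temporal_cache  # previous chunk only

assert torch.equal(acc.vk, torch.full((2, 2), 3.0))  # 1 + 2
assert torch.equal(acc.temporal_cache, torch.tensor([1.0]))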
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
@@ -721,74 +792,6 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
chunk_indices.append(total_frames)
return chunk_indices
def _initialize_kv_cache(self, num_chunks: int, num_blocks: int) -> List:
"""
Initialize KV cache for all chunks.
Args:
num_chunks: Number of chunks
num_blocks: Number of transformer blocks
Returns:
List of KV cache for each chunk
"""
kv_cache = []
for _ in range(num_chunks):
kv_cache.append([[None, None, None] for _ in range(num_blocks)])
return kv_cache
def _accumulate_kv_cache(self, kv_cache: List, chunk_idx: int, num_blocks: int):
"""
Accumulate KV cache from previous chunks.
Args:
kv_cache: List of KV cache for all chunks
chunk_idx: Current chunk index
num_blocks: Number of transformer blocks
Returns:
Accumulated KV cache for current chunk
"""
if chunk_idx == 0:
return kv_cache[0]
cur_kv_cache = kv_cache[chunk_idx]
for block_id in range(num_blocks):
# Copy temporal cache from previous chunk
cur_kv_cache[block_id][2] = kv_cache[chunk_idx - 1][block_id][2]
# Accumulate spatial KV cache from previous chunks
cum_vk, cum_k_sum = None, None
start_chunk_idx = chunk_idx - self.num_cached_blocks if self.num_cached_blocks > 0 else 0
for i in range(start_chunk_idx, chunk_idx):
prev = kv_cache[i][block_id]
if prev[0] is not None and prev[1] is not None:
if cum_vk is None:
cum_vk = prev[0].clone()
cum_k_sum = prev[1].clone()
else:
cum_vk += prev[0]
cum_k_sum += prev[1]
if chunk_idx > 0:
assert cum_vk is not None and cum_k_sum is not None, "KV cache accumulation failed"
cur_kv_cache[block_id][0] = cum_vk
cur_kv_cache[block_id][1] = cum_k_sum
return cur_kv_cache
def _get_num_transformer_blocks(self) -> int:
"""Get the number of transformer blocks in the model."""
if hasattr(self.transformer, "blocks"):
return len(self.transformer.blocks)
elif hasattr(self.transformer, "transformer_blocks"):
return len(self.transformer.transformer_blocks)
elif hasattr(self.transformer, "layers"):
return len(self.transformer.layers)
else:
raise ValueError("Cannot determine number of transformer blocks")
@property
def guidance_scale(self):
@@ -1062,10 +1065,10 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
num_chunks = len(chunk_indices) - 1
# Get number of transformer blocks
num_blocks = self._get_num_transformer_blocks()
num_blocks = self.transformer.config.num_layers
# Initialize KV cache for all chunks
kv_cache = self._initialize_kv_cache(num_chunks, num_blocks)
kv_cache = LongSanaKvCache(num_chunks=num_chunks, num_blocks=num_blocks)
# Output tensor to store denoised results
output = torch.zeros_like(latents)
@@ -1081,7 +1084,9 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
local_latent = latents[:, :, start_f:end_f].clone()
# Accumulate KV cache from previous chunks
chunk_kv_cache = self._accumulate_kv_cache(kv_cache, chunk_idx, num_blocks)
chunk_kv_cache = kv_cache.get_accumulated_chunk_cache(chunk_idx, num_cached_blocks=self.num_cached_blocks)
for block_cache in chunk_kv_cache:
block_cache.disable_save()
# Multi-step denoising for this chunk
with self.progress_bar(total=len(denoising_step_list)) as progress_bar:
@@ -1098,36 +1103,18 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
t = torch.tensor([current_timestep], device=device, dtype=torch.long)
timestep = t.expand(latent_model_input.shape[0])
# Forward pass through transformer with KV cache
transformer_kwargs = {
"encoder_hidden_states": prompt_embeds.to(dtype=transformer_dtype),
"encoder_attention_mask": prompt_attention_mask,
"timestep": timestep,
"return_dict": False,
"save_kv_cache": False, # Don't save during denoising steps
"kv_cache": chunk_kv_cache, # Pass accumulated KV cache
}
if self.attention_kwargs is not None:
transformer_kwargs["attention_kwargs"] = self.attention_kwargs
# Predict flow
model_output = self.transformer(
flow_pred, _ = self.transformer(
latent_model_input.to(dtype=transformer_dtype),
**transformer_kwargs,
encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype),
encoder_attention_mask=prompt_attention_mask,
timestep=timestep,
return_dict=False,
kv_caches=chunk_kv_cache,
attention_kwargs=self.attention_kwargs,
)
# Handle different output formats
if isinstance(model_output, tuple):
if len(model_output) == 2:
flow_pred, updated_kv_cache = model_output
# Update chunk_kv_cache with new values
if updated_kv_cache is not None:
chunk_kv_cache = updated_kv_cache
else:
flow_pred = model_output[0]
else:
flow_pred = model_output
flow_pred = flow_pred.float()
@@ -1191,29 +1178,21 @@ class LongSanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
latent_for_cache = output[:, :, start_f:end_f]
timestep_zero = torch.zeros(latent_for_cache.shape[0], device=device, dtype=torch.long)
cache_kwargs = {
"encoder_hidden_states": prompt_embeds.to(dtype=transformer_dtype),
"encoder_attention_mask": prompt_attention_mask,
"timestep": timestep_zero,
"return_dict": False,
"save_kv_cache": True, # Enable saving during cache update
"kv_cache": chunk_kv_cache,
}
if self.attention_kwargs is not None:
cache_kwargs["attention_kwargs"] = self.attention_kwargs
for block_cache in chunk_kv_cache:
block_cache.enable_save()
# Forward pass to update KV cache
cache_output = self.transformer(
_, chunk_kv_cache = self.transformer(
latent_for_cache.to(dtype=transformer_dtype),
**cache_kwargs,
encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype),
encoder_attention_mask=prompt_attention_mask,
timestep=timestep_zero,
return_dict=False,
kv_caches=chunk_kv_cache,
attention_kwargs=self.attention_kwargs,
)
# Extract updated KV cache if returned
if isinstance(cache_output, tuple) and len(cache_output) == 2:
_, updated_kv_cache = cache_output
if updated_kv_cache is not None:
kv_cache[chunk_idx] = updated_kv_cache
kv_cache.update_chunk_cache(chunk_idx, chunk_kv_cache)
if XLA_AVAILABLE:
xm.mark_step()
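Putting the pieces together, each chunk goes through three phases: read an accumulated cache with saving disabled, run the denoising steps, then re-enable saving and do one pass at timestep 0 to record the chunk's own state. A runnable toy of that control flow (stub model and integer "latents"; only the gating logic mirrors the pipeline):

from dataclasses import dataclass

@dataclass
class BlockCache:  # stand-in for SanaBlockKvCache
    state: int = 0
    _enable_save: bool = False

    def enable_save(self):
        self._enable_save = True

    def disable_save(self):
        self._enable_save = False

    def maybe_save(self, state):
        if self._enable_save:
            self.state = state

def fake_transformer(x, kv_caches):
    for cache in kv_caches:
        cache.maybe_save(x)  # persists only when saving is enabled
    return x * 2, kv_caches

num_chunks, num_blocks = 3, 2
store = [[BlockCache() for _ in range(num_blocks)] for _ in range(num_chunks)]

for chunk_idx in range(num_chunks):
    chunk_cache = store[chunk_idx]    # stand-in for get_accumulated_chunk_cache
    for cache in chunk_cache:
        cache.disable_save()          # denoising steps leave the cache untouched
    for _ in range(4):                # multi-step denoising
        pred, _ = fake_transformer(chunk_idx + 1, chunk_cache)
    for cache in chunk_cache:
        cache.enable_save()           # one extra pass records this chunk's state
    _, chunk_cache = fake_transformer(chunk_idx + 1, chunk_cache)
    store[chunk_idx] = chunk_cache    # stand-in for update_chunk_cache

assert [store[i][0].state for i in range(num_chunks)] == [1, 2, 3]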