diff --git a/src/diffusers/models/transformers/transformer_sana_video_causal.py b/src/diffusers/models/transformers/transformer_sana_video_causal.py
index 233c4b67b9..3b27a34102 100644
--- a/src/diffusers/models/transformers/transformer_sana_video_causal.py
+++ b/src/diffusers/models/transformers/transformer_sana_video_causal.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import math
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -21,19 +21,67 @@ from torch import nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
 from ..attention import AttentionMixin
 from ..attention_dispatch import dispatch_attention_fn
 from ..attention_processor import Attention
 from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
-from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormSingle, RMSNorm
+from dataclasses import dataclass
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
+@dataclass
+class SanaBlockKvCache:
+    """Per-block cache state: accumulated linear-attention `vk`/`k_sum` and the temporal-conv tail frames."""
+
+    vk: Optional[torch.Tensor] = None
+    k_sum: Optional[torch.Tensor] = None
+    temporal_cache: Optional[torch.Tensor] = None
+    _enable_save: bool = False
+
+    def disable_save(self):
+        self._enable_save = False
+
+    def enable_save(self):
+        self._enable_save = True
+
+    def maybe_save(
+        self,
+        vk: Optional[torch.Tensor] = None,
+        k_sum: Optional[torch.Tensor] = None,
+        temporal_cache: Optional[torch.Tensor] = None,
+    ):
+        if not self._enable_save:
+            return
+
+        if vk is not None:
+            self.vk = vk.detach().clone()
+        if k_sum is not None:
+            self.k_sum = k_sum.detach().clone()
+        if temporal_cache is not None:
+            self.temporal_cache = temporal_cache.detach().clone()
+
+
+@dataclass
+class SanaVideoCausalTransformer3DModelOutput(BaseOutput):
+    """
+    The output of [`SanaVideoCausalTransformer3DModel`].
+
+    Args:
+        sample (`torch.Tensor` of shape `(batch_size, num_frames, height, width, num_channels)`):
+            The hidden states output conditioned on the `encoder_hidden_states` input.
+        kv_caches (`List[SanaBlockKvCache]`, *optional*):
+            The per-block KV caches of the transformer blocks.
+ """ + + sample: "torch.Tensor" # noqa: F821 + kv_caches: Optional[List[SanaBlockKvCache]] = None + + class CachedGLUMBConvTemp(nn.Module): def __init__( self, @@ -65,12 +113,11 @@ class CachedGLUMBConvTemp(nn.Module): def forward( self, hidden_states: torch.Tensor, - save_kv_cache: bool = False, - kv_cache: Optional[list] = None, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, list]]: + kv_cache: Optional[SanaBlockKvCache] = None, + ) -> Tuple[torch.Tensor, Optional[SanaBlockKvCache]]: """ hidden_states: shape [B, T, H, W, C] - kv_cache: list, with kv_cache[0/1/2] for optional cached states (only kv_cache[2] is used here for temporal) + kv_cache: SanaBlockKvCache, with optional cached states (only temporal_cache is used here for temporal) """ if self.residual_connection: @@ -99,17 +146,13 @@ class CachedGLUMBConvTemp(nn.Module): # If using cache, prepend cached frames from last chunk along time axis (dim 2) if kv_cache is not None: - if len(kv_cache) < 3: - kv_cache.extend([None] * (3 - len(kv_cache))) - if kv_cache[2] is not None: - hidden_states_temporal_in = torch.cat([kv_cache[2], hidden_states_temporal], dim=2) - padded_size = kv_cache[2].shape[2] + if kv_cache.temporal_cache is not None: + hidden_states_temporal_in = torch.cat([kv_cache.temporal_cache, hidden_states_temporal], dim=2) + padded_size = kv_cache.temporal_cache.shape[2] # Save last padding_size frames for next chunk - if save_kv_cache: - kv_cache[2] = hidden_states_temporal[:, :, -padding_size:, :].detach().clone() - else: - if save_kv_cache: - kv_cache = [None, None, hidden_states_temporal[:, :, -padding_size:, :].detach().clone()] + kv_cache.maybe_save( + temporal_cache=hidden_states_temporal[:, :, -padding_size:, :], + ) t_conv_out = self.conv_temp(hidden_states_temporal_in)[:, :, padded_size:] hidden_states = hidden_states_temporal + t_conv_out @@ -121,9 +164,7 @@ class CachedGLUMBConvTemp(nn.Module): if self.residual_connection: hidden_states = hidden_states + residual - if kv_cache is not None or save_kv_cache: - return hidden_states, kv_cache - return hidden_states + return hidden_states, kv_cache class SanaCausalLinearAttnProcessor1_0: @@ -139,9 +180,8 @@ class SanaCausalLinearAttnProcessor1_0: encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, rotary_emb: Optional[torch.Tensor] = None, - save_kv_cache: bool = False, - kv_cache: Optional[list] = None, - ) -> torch.Tensor: + kv_cache: Optional[SanaBlockKvCache] = None, + ) -> Tuple[torch.Tensor, Optional[SanaBlockKvCache]]: original_dtype = hidden_states.dtype if encoder_hidden_states is None: @@ -205,14 +245,8 @@ class SanaCausalLinearAttnProcessor1_0: # Handle KV cache for autoregressive generation if kv_cache is not None: - cached_vk, cached_k_sum = kv_cache[0], kv_cache[1] - - # Save current step's KV to cache if requested - if save_kv_cache: - kv_cache[0] = scores.detach().clone() - kv_cache[1] = k_sum.detach().clone() - - # Accumulate with previous cached values + cached_vk, cached_k_sum = kv_cache.vk, kv_cache.k_sum + kv_cache.maybe_save(vk=scores, k_sum=k_sum) if cached_vk is not None and cached_k_sum is not None: scores = scores + cached_vk k_sum = k_sum + cached_k_sum @@ -234,11 +268,7 @@ class SanaCausalLinearAttnProcessor1_0: hidden_states = attn.to_out[0](hidden_states) hidden_states = attn.to_out[1](hidden_states) - # Return with cache if applicable - if kv_cache is not None: - return hidden_states, kv_cache - - return hidden_states + return hidden_states, kv_cache # Copied from 
@@ -442,14 +472,10 @@ class SanaVideoCausalTransformerBlock(nn.Module):
         mlp_ratio: float = 3.0,
         qk_norm: Optional[str] = "rms_norm_across_heads",
         rope_max_seq_len: int = 1024,
-        self_attn_processor: Optional[nn.Module] = None,
-        ffn_processor: Optional[nn.Module] = None,
     ) -> None:
         super().__init__()
 
         # 1. Self Attention - must use causal linear attention
-        if self_attn_processor is None:
-            self_attn_processor = SanaCausalLinearAttnProcessor1_0()
         self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=norm_eps)
         self.attn1 = Attention(
             query_dim=dim,
@@ -460,7 +486,7 @@
             dropout=dropout,
             bias=attention_bias,
             cross_attention_dim=None,
-            processor=self_attn_processor,
+            processor=SanaCausalLinearAttnProcessor1_0(),
         )
 
         # 2. Cross Attention
@@ -480,9 +506,7 @@
         )
 
         # 3. Feed-forward - must use cached conv
-        if ffn_processor is None:
-            ffn_processor = CachedGLUMBConvTemp
-        self.ff = ffn_processor(dim, dim, mlp_ratio, norm_type=None, residual_connection=False)
+        self.ff = CachedGLUMBConvTemp(dim, dim, mlp_ratio, norm_type=None, residual_connection=False)
 
         self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
 
@@ -497,9 +521,8 @@
         height: int = None,
         width: int = None,
         rotary_emb: Optional[torch.Tensor] = None,
-        save_kv_cache: bool = False,
-        kv_cache: Optional[list] = None,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, list]]:
+        kv_cache: Optional[SanaBlockKvCache] = None,
+    ) -> Tuple[torch.Tensor, Optional[SanaBlockKvCache]]:
         batch_size = hidden_states.shape[0]
 
         # 1. Modulation
@@ -513,17 +536,11 @@
         norm_hidden_states = norm_hidden_states.to(hidden_states.dtype)
 
         # Causal linear attention always supports kv_cache
-        attn_result = self.attn1(
+        attn_output, kv_cache = self.attn1(
             norm_hidden_states,
             rotary_emb=rotary_emb,
-            save_kv_cache=save_kv_cache,
             kv_cache=kv_cache,
         )
-        if isinstance(attn_result, tuple):
-            attn_output, kv_cache = attn_result
-        else:
-            attn_output = attn_result
-
         hidden_states = hidden_states + gate_msa * attn_output
 
         # 3. Cross Attention (no cache)
@@ -542,22 +559,15 @@
         norm_hidden_states = norm_hidden_states.unflatten(1, (frames, height, width))
 
         # Cached conv always supports kv_cache
-        ff_result = self.ff(
+        ff_output, kv_cache = self.ff(
             norm_hidden_states,
-            save_kv_cache=save_kv_cache,
             kv_cache=kv_cache,
         )
-        if isinstance(ff_result, tuple):
-            ff_output, kv_cache = ff_result
-        else:
-            ff_output = ff_result
         ff_output = ff_output.flatten(1, 3)
 
         hidden_states = hidden_states + gate_mlp * ff_output
 
-        if kv_cache is not None or save_kv_cache:
-            return hidden_states, kv_cache
-        return hidden_states
+        return hidden_states, kv_cache
 
 
 class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, AttentionMixin):
@@ -667,8 +677,6 @@
                     norm_eps=norm_eps,
                     mlp_ratio=mlp_ratio,
                     qk_norm=qk_norm,
-                    self_attn_processor=SanaCausalLinearAttnProcessor1_0(),
-                    ffn_processor=CachedGLUMBConvTemp,
                 )
                 for _ in range(num_layers)
             ]
@@ -690,11 +698,9 @@
         encoder_attention_mask: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
-        controlnet_block_samples: Optional[Tuple[torch.Tensor]] = None,
-        save_kv_cache: bool = False,
-        kv_cache: Optional[list] = None,
+        kv_caches: Optional[List[SanaBlockKvCache]] = None,
         return_dict: bool = True,
-    ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
+    ) -> Union[Tuple[torch.Tensor, ...], SanaVideoCausalTransformer3DModelOutput]:
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)
@@ -752,12 +758,12 @@
         # 2. Transformer blocks with KV cache
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-            # Note: gradient checkpointing doesn't support kv_cache (requires tuple return)
-            if kv_cache is not None:
+            # Note: KV caching is an inference-time feature; it is disabled under gradient checkpointing
+            if kv_caches is not None:
                 logger.warning("KV cache is not supported with gradient checkpointing. Disabling KV cache.")
Disabling KV cache.") - kv_cache = None + kv_caches = None for index_block, block in enumerate(self.transformer_blocks): - hidden_states = self._gradient_checkpointing_func( + hidden_states, _ = self._gradient_checkpointing_func( block, hidden_states, attention_mask, @@ -768,16 +774,14 @@ class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixi post_patch_height, post_patch_width, rotary_emb, + kv_cache=None, ) - if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples): - hidden_states = hidden_states + controlnet_block_samples[index_block - 1] - else: for index_block, block in enumerate(self.transformer_blocks): # Get kv_cache for this block if available - block_kv_cache = kv_cache[index_block] if kv_cache is not None else None + block_kv_cache = kv_caches[index_block] if kv_caches is not None else None - block_result = block( + hidden_states, block_kv_cache = block( hidden_states, attention_mask, encoder_hidden_states, @@ -787,20 +791,12 @@ class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixi post_patch_height, post_patch_width, rotary_emb, - save_kv_cache=save_kv_cache, kv_cache=block_kv_cache, ) # Handle return value (could be tensor or tuple) - if isinstance(block_result, tuple): - hidden_states, updated_kv_cache = block_result - if kv_cache is not None: - kv_cache[index_block] = updated_kv_cache - else: - hidden_states = block_result - - if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples): - hidden_states = hidden_states + controlnet_block_samples[index_block - 1] + if kv_caches is not None: + kv_caches[index_block] = block_kv_cache # 3. Normalization hidden_states = self.norm_out(hidden_states, embedded_timestep, self.scale_shift_table) @@ -819,10 +815,6 @@ class SanaVideoCausalTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixi unscale_lora_layers(self, lora_scale) if not return_dict: - if kv_cache is not None or save_kv_cache: - return (output, kv_cache) - return (output,) + return (output, kv_caches) - if kv_cache is not None or save_kv_cache: - return Transformer2DModelOutput(sample=output), kv_cache - return Transformer2DModelOutput(sample=output) + return SanaVideoCausalTransformer3DModelOutput(sample=output, kv_cache=kv_caches) diff --git a/src/diffusers/pipelines/sana_video/pipeline_longsana.py b/src/diffusers/pipelines/sana_video/pipeline_longsana.py index 641849a81b..775af8c73c 100644 --- a/src/diffusers/pipelines/sana_video/pipeline_longsana.py +++ b/src/diffusers/pipelines/sana_video/pipeline_longsana.py @@ -26,7 +26,7 @@ from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFa from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...loaders import SanaLoraLoaderMixin from ...models import AutoencoderDC, AutoencoderKLWan -from ...models.transformers.transformer_sana_video_causal import SanaVideoCausalTransformer3DModel +from ...models.transformers.transformer_sana_video_causal import SanaVideoCausalTransformer3DModel, SanaBlockKvCache from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( BACKENDS_MAPPING, @@ -97,6 +97,77 @@ EXAMPLE_DOC_STRING = """ """ +class LongSanaKvCache: + def __init__(self, num_chunks: int, num_blocks: int): + """ + Initialize KV cache for all chunks. 
+
+        Args:
+            num_chunks: Number of chunks
+            num_blocks: Number of transformer blocks
+        """
+        kv_caches = []
+        for _ in range(num_chunks):
+            kv_caches.append([SanaBlockKvCache(vk=None, k_sum=None, temporal_cache=None) for _ in range(num_blocks)])
+        self.num_chunks = num_chunks
+        self.num_blocks = num_blocks
+        self.kv_caches = kv_caches
+
+    def get_chunk_cache(self, chunk_idx: int) -> List[SanaBlockKvCache]:
+        return self.kv_caches[chunk_idx]
+
+    def get_block_cache(self, chunk_idx: int, block_idx: int) -> SanaBlockKvCache:
+        return self.kv_caches[chunk_idx][block_idx]
+
+    def update_chunk_cache(self, chunk_idx: int, chunk_kv_cache: List[SanaBlockKvCache]):
+        self.kv_caches[chunk_idx] = chunk_kv_cache
+
+    def get_accumulated_chunk_cache(self, chunk_idx: int, num_cached_blocks: int = -1) -> List[SanaBlockKvCache]:
+        """
+        Accumulate KV cache from previous chunks.
+
+        Args:
+            chunk_idx: Current chunk index
+            num_cached_blocks: Number of previous chunks whose caches are accumulated; -1 means all previous chunks.
+
+        Returns:
+            Accumulated KV cache for the current chunk, a list of SanaBlockKvCache.
+        """
+        if chunk_idx == 0:
+            return self.kv_caches[0]
+
+        accumulated_kv_caches = []  # a list of SanaBlockKvCache
+        for block_id in range(self.num_blocks):
+            start_chunk_idx = chunk_idx - num_cached_blocks if num_cached_blocks > 0 else 0
+            # Initialize an empty accumulated block cache (vk, k_sum, temporal_cache all None)
+            acc_block_cache = SanaBlockKvCache(vk=None, k_sum=None, temporal_cache=None)
+            # Accumulate spatial KV cache from previous chunks
+
+            for prev_chunk_idx in range(start_chunk_idx, chunk_idx):
+                prev_kv_cache = self.kv_caches[prev_chunk_idx][block_id]
+
+                if prev_kv_cache.vk is None or prev_kv_cache.k_sum is None:
+                    continue
+
+                if acc_block_cache.vk is not None and acc_block_cache.k_sum is not None:
+                    acc_block_cache.vk += prev_kv_cache.vk
+                    acc_block_cache.k_sum += prev_kv_cache.k_sum
+                else:
+                    # Initialize vk and k_sum from the first contributing chunk's cache
+                    acc_block_cache.vk = prev_kv_cache.vk.clone()
+                    acc_block_cache.k_sum = prev_kv_cache.k_sum.clone()
+            # Copy the temporal cache from the immediately preceding chunk
+            acc_block_cache.temporal_cache = self.kv_caches[chunk_idx - 1][block_id].temporal_cache
+
+            accumulated_kv_caches.append(acc_block_cache)
+
+        return accumulated_kv_caches
+
+
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
     scheduler,
@@ -721,74 +792,6 @@
     chunk_indices.append(total_frames)
         return chunk_indices
 
-    def _initialize_kv_cache(self, num_chunks: int, num_blocks: int) -> List:
-        """
-        Initialize KV cache for all chunks.
-
-        Args:
-            num_chunks: Number of chunks
-            num_blocks: Number of transformer blocks
-
-        Returns:
-            List of KV cache for each chunk
-        """
-        kv_cache = []
-        for _ in range(num_chunks):
-            kv_cache.append([[None, None, None] for _ in range(num_blocks)])
-        return kv_cache
-
-    def _accumulate_kv_cache(self, kv_cache: List, chunk_idx: int, num_blocks: int):
-        """
-        Accumulate KV cache from previous chunks.
-
-        Args:
-            kv_cache: List of KV cache for all chunks
-            chunk_idx: Current chunk index
-            num_blocks: Number of transformer blocks
-
-        Returns:
-            Accumulated KV cache for current chunk
-        """
-        if chunk_idx == 0:
-            return kv_cache[0]
-
-        cur_kv_cache = kv_cache[chunk_idx]
-        for block_id in range(num_blocks):
-            # Copy temporal cache from previous chunk
-            cur_kv_cache[block_id][2] = kv_cache[chunk_idx - 1][block_id][2]
-
-            # Accumulate spatial KV cache from previous chunks
-            cum_vk, cum_k_sum = None, None
-            start_chunk_idx = chunk_idx - self.num_cached_blocks if self.num_cached_blocks > 0 else 0
-
-            for i in range(start_chunk_idx, chunk_idx):
-                prev = kv_cache[i][block_id]
-                if prev[0] is not None and prev[1] is not None:
-                    if cum_vk is None:
-                        cum_vk = prev[0].clone()
-                        cum_k_sum = prev[1].clone()
-                    else:
-                        cum_vk += prev[0]
-                        cum_k_sum += prev[1]
-
-            if chunk_idx > 0:
-                assert cum_vk is not None and cum_k_sum is not None, "KV cache accumulation failed"
-
-            cur_kv_cache[block_id][0] = cum_vk
-            cur_kv_cache[block_id][1] = cum_k_sum
-
-        return cur_kv_cache
-
-    def _get_num_transformer_blocks(self) -> int:
-        """Get the number of transformer blocks in the model."""
-        if hasattr(self.transformer, "blocks"):
-            return len(self.transformer.blocks)
-        elif hasattr(self.transformer, "transformer_blocks"):
-            return len(self.transformer.transformer_blocks)
-        elif hasattr(self.transformer, "layers"):
-            return len(self.transformer.layers)
-        else:
-            raise ValueError("Cannot determine number of transformer blocks")
 
     @property
     def guidance_scale(self):
@@ -1062,10 +1065,10 @@
         num_chunks = len(chunk_indices) - 1
 
         # Get number of transformer blocks
-        num_blocks = self._get_num_transformer_blocks()
+        num_blocks = self.transformer.config.num_layers
 
         # Initialize KV cache for all chunks
-        kv_cache = self._initialize_kv_cache(num_chunks, num_blocks)
+        kv_cache = LongSanaKvCache(num_chunks=num_chunks, num_blocks=num_blocks)
 
         # Output tensor to store denoised results
         output = torch.zeros_like(latents)
@@ -1081,7 +1084,10 @@
             local_latent = latents[:, :, start_f:end_f].clone()
 
             # Accumulate KV cache from previous chunks
-            chunk_kv_cache = self._accumulate_kv_cache(kv_cache, chunk_idx, num_blocks)
+            chunk_kv_cache = kv_cache.get_accumulated_chunk_cache(chunk_idx, num_cached_blocks=self.num_cached_blocks)
+            # Denoising steps only read the cache; saving is enabled for the dedicated pass below
+            for block_cache in chunk_kv_cache:
+                block_cache.disable_save()
 
             # Multi-step denoising for this chunk
             with self.progress_bar(total=len(denoising_step_list)) as progress_bar:
@@ -1098,36 +1103,18 @@
                     t = torch.tensor([current_timestep], device=device, dtype=torch.long)
                     timestep = t.expand(latent_model_input.shape[0])
 
-                    # Forward pass through transformer with KV cache
-                    transformer_kwargs = {
-                        "encoder_hidden_states": prompt_embeds.to(dtype=transformer_dtype),
-                        "encoder_attention_mask": prompt_attention_mask,
-                        "timestep": timestep,
-                        "return_dict": False,
-                        "save_kv_cache": False,  # Don't save during denoising steps
-                        "kv_cache": chunk_kv_cache,  # Pass accumulated KV cache
-                    }
-
-                    if self.attention_kwargs is not None:
-                        transformer_kwargs["attention_kwargs"] = self.attention_kwargs
 
                     # Predict flow
-                    model_output = self.transformer(
+                    flow_pred, _ = self.transformer(
                         latent_model_input.to(dtype=transformer_dtype),
-                        **transformer_kwargs,
+                        encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype),
+                        encoder_attention_mask=prompt_attention_mask,
+                        timestep=timestep,
+                        return_dict=False,
+                        kv_caches=chunk_kv_cache,
+                        attention_kwargs=self.attention_kwargs,
                     )
 
-                    # Handle different output formats
-                    if isinstance(model_output, tuple):
-                        if len(model_output) == 2:
-                            flow_pred, updated_kv_cache = model_output
-                            # Update chunk_kv_cache with new values
-                            if updated_kv_cache is not None:
-                                chunk_kv_cache = updated_kv_cache
-                        else:
-                            flow_pred = model_output[0]
-                    else:
-                        flow_pred = model_output
 
                     flow_pred = flow_pred.float()
@@ -1191,29 +1178,21 @@
                 latent_for_cache = output[:, :, start_f:end_f]
                 timestep_zero = torch.zeros(latent_for_cache.shape[0], device=device, dtype=torch.long)
 
-                cache_kwargs = {
-                    "encoder_hidden_states": prompt_embeds.to(dtype=transformer_dtype),
-                    "encoder_attention_mask": prompt_attention_mask,
-                    "timestep": timestep_zero,
-                    "return_dict": False,
-                    "save_kv_cache": True,  # Enable saving during cache update
-                    "kv_cache": chunk_kv_cache,
-                }
-
-                if self.attention_kwargs is not None:
-                    cache_kwargs["attention_kwargs"] = self.attention_kwargs
+                for block_cache in chunk_kv_cache:
+                    block_cache.enable_save()
 
                 # Forward pass to update KV cache
-                cache_output = self.transformer(
+                _, chunk_kv_cache = self.transformer(
                     latent_for_cache.to(dtype=transformer_dtype),
-                    **cache_kwargs,
+                    encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype),
+                    encoder_attention_mask=prompt_attention_mask,
+                    timestep=timestep_zero,
+                    return_dict=False,
+                    kv_caches=chunk_kv_cache,
+                    attention_kwargs=self.attention_kwargs,
                 )
 
-                # Extract updated KV cache if returned
-                if isinstance(cache_output, tuple) and len(cache_output) == 2:
-                    _, updated_kv_cache = cache_output
-                    if updated_kv_cache is not None:
-                        kv_cache[chunk_idx] = updated_kv_cache
+                kv_cache.update_chunk_cache(chunk_idx, chunk_kv_cache)
 
                 if XLA_AVAILABLE:
                     xm.mark_step()
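
Reviewer note: below is a minimal sketch (not part of the diff) of the accumulation semantics the new cache classes implement. It assumes this patch is applied so that `LongSanaKvCache` is importable from `pipeline_longsana`; the tensor shapes used here are placeholders, not the model's real cache layouts.

import torch

from diffusers.pipelines.sana_video.pipeline_longsana import LongSanaKvCache

# Three chunks, two transformer blocks.
caches = LongSanaKvCache(num_chunks=3, num_blocks=2)

# Pretend the dedicated t=0 "save" pass already ran for chunks 0 and 1.
# maybe_save() is a no-op unless enable_save() was called first.
for chunk_idx in range(2):
    for block_cache in caches.get_chunk_cache(chunk_idx):
        block_cache.enable_save()
        block_cache.maybe_save(
            vk=torch.ones(1, 4, 8, 8),  # placeholder shape
            k_sum=torch.ones(1, 4, 8, 1),  # placeholder shape
            temporal_cache=torch.ones(1, 8, 2, 4),  # placeholder shape
        )

# Chunk 2 then reads vk / k_sum summed over chunks 0 and 1, plus the temporal
# tail taken from chunk 1 only.
accumulated = caches.get_accumulated_chunk_cache(2, num_cached_blocks=-1)
print(accumulated[0].vk[0, 0, 0, 0].item())  # 2.0, since both previous chunks contributed
print(accumulated[0].temporal_cache is caches.get_block_cache(1, 0).temporal_cache)  # True

This mirrors the pipeline flow above: disable_save() while denoising a chunk against the accumulated, read-only state; enable_save() for the single t=0 forward pass that records the chunk's own vk/k_sum/temporal_cache; then update_chunk_cache() to store it for later chunks.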