From 7238d40dd9859dbcec7ac7ca87c9b13f3aea3558 Mon Sep 17 00:00:00 2001 From: toilaluan Date: Sun, 16 Nov 2025 05:09:44 +0000 Subject: [PATCH] add stop_predicts (cooldown) --- src/diffusers/hooks/taylorseer_cache.py | 126 +++++++++++++++--------- 1 file changed, 79 insertions(+), 47 deletions(-) diff --git a/src/diffusers/hooks/taylorseer_cache.py b/src/diffusers/hooks/taylorseer_cache.py index 3c5d0a2f39..cb6b7fedd5 100644 --- a/src/diffusers/hooks/taylorseer_cache.py +++ b/src/diffusers/hooks/taylorseer_cache.py @@ -1,6 +1,6 @@ import torch from dataclasses import dataclass -from typing import Callable, Optional, List, Dict +from typing import Callable, Optional, List, Dict, Tuple from .hooks import ModelHook import math from ..models.attention import Attention @@ -12,23 +12,28 @@ from ..hooks import HookRegistry from ..utils import logging import re from collections import defaultdict + + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + _TAYLORSEER_ATTENTION_CACHE_HOOK = "taylorseer_attention_cache" -SPECIAL_CACHE_IDENTIFIERS = { - "flux": [ - r"transformer_blocks\.\d+\.attn", - r"transformer_blocks\.\d+\.ff", - r"transformer_blocks\.\d+\.ff_context", - r"single_transformer_blocks\.\d+\.proj_out", - ] -} -SKIP_COMPUTE_IDENTIFIERS = { - "flux": [ - r"single_transformer_blocks\.\d+\.attn", - r"single_transformer_blocks\.\d+\.proj_mlp", - r"single_transformer_blocks\.\d+\.act_mlp", - ] +# Predefined cache templates for optimized architectures +_CACHE_TEMPLATES = { + "flux": { + "cache": [ + r"transformer_blocks\.\d+\.attn", + r"transformer_blocks\.\d+\.ff", + r"transformer_blocks\.\d+\.ff_context", + r"single_transformer_blocks\.\d+\.proj_out", + ], + "skip": [ + r"single_transformer_blocks\.\d+\.attn", + r"single_transformer_blocks\.\d+\.proj_mlp", + r"single_transformer_blocks\.\d+\.act_mlp", + ], + }, } @@ -41,24 +46,39 @@ class TaylorSeerCacheConfig: Attributes: warmup_steps (int, defaults to 3): Number of warmup steps without caching. predict_steps (int, defaults to 5): Number of prediction (cache) steps between non-cached steps. + stop_predicts (Optional[int], defaults to None): Step after which predictions are stopped and full computation is always performed. max_order (int, defaults to 1): Maximum order of Taylor series expansion to approximate the features. taylor_factors_dtype (torch.dtype, defaults to torch.float32): Data type for Taylor series expansion factors. architecture (str, defaults to None): Architecture for which the cache is applied. If we know the architecture, we can use the special cache identifiers. - skip_compute_identifiers (List[str], defaults to []): Identifiers for modules to skip computation. - special_cache_identifiers (List[str], defaults to []): Identifiers for modules to use special cache. + skip_identifiers (List[str], defaults to []): Identifiers for modules to skip computation. + cache_identifiers (List[str], defaults to []): Identifiers for modules to cache. + + By default, this approximation can be applied to all attention modules, but in some architectures, where the outputs of attention modules are not used for any residual computation, we can skip this attention cache step, so we have to identify the next modules to cache. + Example: + ```python + ... + def forward(self, x: torch.Tensor) -> torch.Tensor: + attn_output = self.attention(x) # mark this attention module to skip computation + ffn_output = self.ffn(attn_output) # ffn_output will be cached + return ffn_output + ``` """ warmup_steps: int = 3 predict_steps: int = 5 + stop_predicts: Optional[int] = None max_order: int = 1 taylor_factors_dtype: torch.dtype = torch.float32 architecture: str | None = None - skip_compute_identifiers: List[str] = None - special_cache_identifiers: List[str] = None + skip_identifiers: List[str] = None + cache_identifiers: List[str] = None def __repr__(self) -> str: - return f"TaylorSeerCacheConfig(warmup_steps={self.warmup_steps}, predict_steps={self.predict_steps}, max_order={self.max_order}, taylor_factors_dtype={self.taylor_factors_dtype}, architecture={self.architecture}, skip_compute_identifiers={self.skip_compute_identifiers}, special_cache_identifiers={self.special_cache_identifiers})" + return f"TaylorSeerCacheConfig(warmup_steps={self.warmup_steps}, predict_steps={self.predict_steps}, stop_predicts={self.stop_predicts}, max_order={self.max_order}, taylor_factors_dtype={self.taylor_factors_dtype}, architecture={self.architecture}, skip_identifiers={self.skip_identifiers}, cache_identifiers={self.cache_identifiers})" + @classmethod + def get_identifiers_template(self) -> Dict[str, Dict[str, List[str]]]: + return _CACHE_TEMPLATES class TaylorSeerOutputState: """ @@ -174,18 +194,20 @@ class TaylorSeerAttentionCacheHook(ModelHook): max_order: int, warmup_steps: int, taylor_factors_dtype: torch.dtype, - is_skip_compute: bool = False, + stop_predicts: Optional[int] = None, + is_skip: bool = False, ): super().__init__() self.module_name = module_name self.predict_steps = predict_steps self.max_order = max_order self.warmup_steps = warmup_steps + self.stop_predicts = stop_predicts self.step_counter = -1 self.states: Optional[List[TaylorSeerOutputState]] = None self.num_outputs: Optional[int] = None self.taylor_factors_dtype = taylor_factors_dtype - self.is_skip_compute = is_skip_compute + self.is_skip = is_skip def initialize_hook(self, module: torch.nn.Module): self.step_counter = -1 @@ -208,7 +230,7 @@ class TaylorSeerAttentionCacheHook(ModelHook): self.num_outputs = len(attention_outputs) self.states = [ TaylorSeerOutputState( - self.module_name, self.taylor_factors_dtype, module_dtype, is_skip=self.is_skip_compute + self.module_name, self.taylor_factors_dtype, module_dtype, is_skip=self.is_skip ) for _ in range(self.num_outputs) ] @@ -218,22 +240,31 @@ class TaylorSeerAttentionCacheHook(ModelHook): ) return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs) - should_predict = self.states[0].remaining_predictions > 0 - if is_warmup_phase or not should_predict: - # Full compute during warmup or when refresh needed + if self.stop_predicts is not None and self.step_counter >= self.stop_predicts: + # After stop_predicts: always full compute without updating state attention_outputs = self.fn_ref.original_forward(*args, **kwargs) if isinstance(attention_outputs, torch.Tensor): attention_outputs = [attention_outputs] else: attention_outputs = list(attention_outputs) - is_first_update = self.step_counter == 0 # Only True for the very first step - for i, features in enumerate(attention_outputs): - self.states[i].update(features, self.step_counter, self.max_order, self.predict_steps, is_first_update) return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs) else: - # Predict using Taylor series - predicted_outputs = [state.predict(self.step_counter) for state in self.states] - return predicted_outputs[0] if self.num_outputs == 1 else tuple(predicted_outputs) + should_predict = self.states[0].remaining_predictions > 0 + if is_warmup_phase or not should_predict: + # Full compute during warmup or when refresh needed + attention_outputs = self.fn_ref.original_forward(*args, **kwargs) + if isinstance(attention_outputs, torch.Tensor): + attention_outputs = [attention_outputs] + else: + attention_outputs = list(attention_outputs) + is_first_update = self.step_counter == 0 # Only True for the very first step + for i, features in enumerate(attention_outputs): + self.states[i].update(features, self.step_counter, self.max_order, self.predict_steps, is_first_update) + return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs) + else: + # Predict using Taylor series + predicted_outputs = [state.predict(self.step_counter) for state in self.states] + return predicted_outputs[0] if self.num_outputs == 1 else tuple(predicted_outputs) def reset_state(self, module: torch.nn.Module) -> None: self.states = None @@ -259,23 +290,23 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfi >>> apply_taylorseer_cache(pipe.transformer, config) ``` """ - if config.skip_compute_identifiers: - skip_compute_identifiers = config.skip_compute_identifiers + if config.skip_identifiers: + skip_identifiers = config.skip_identifiers else: - skip_compute_identifiers = SKIP_COMPUTE_IDENTIFIERS.get(config.architecture, []) + skip_identifiers = _CACHE_TEMPLATES.get(config.architecture, {}).get("skip", []) - if config.special_cache_identifiers: - special_cache_identifiers = config.special_cache_identifiers + if config.cache_identifiers: + cache_identifiers = config.cache_identifiers else: - special_cache_identifiers = SPECIAL_CACHE_IDENTIFIERS.get(config.architecture, []) + cache_identifiers = _CACHE_TEMPLATES.get(config.architecture, {}).get("cache", []) - logger.debug(f"Skip compute identifiers: {skip_compute_identifiers}") - logger.debug(f"Special cache identifiers: {special_cache_identifiers}") + logger.debug(f"Skip identifiers: {skip_identifiers}") + logger.debug(f"Cache identifiers: {cache_identifiers}") for name, submodule in module.named_modules(): - if (skip_compute_identifiers and special_cache_identifiers) or (special_cache_identifiers): - if any(re.fullmatch(identifier, name) for identifier in skip_compute_identifiers) or any( - re.fullmatch(identifier, name) for identifier in special_cache_identifiers + if (skip_identifiers and cache_identifiers) or (cache_identifiers): + if any(re.fullmatch(identifier, name) for identifier in skip_identifiers) or any( + re.fullmatch(identifier, name) for identifier in cache_identifiers ): logger.debug(f"Applying TaylorSeer cache to {name}") _apply_taylorseer_cache_hook(name, submodule, config) @@ -293,8 +324,8 @@ def _apply_taylorseer_cache_hook(name: str, module: Attention, config: TaylorSee config (TaylorSeerCacheConfig): Configuration for the cache. """ - is_skip_compute = any( - re.fullmatch(identifier, name) for identifier in SKIP_COMPUTE_IDENTIFIERS.get(config.architecture, []) + is_skip = any( + re.fullmatch(identifier, name) for identifier in _CACHE_TEMPLATES.get(config.architecture, {}).get("skip", []) ) registry = HookRegistry.check_if_exists_or_initialize(module) @@ -305,7 +336,8 @@ def _apply_taylorseer_cache_hook(name: str, module: Attention, config: TaylorSee config.max_order, config.warmup_steps, config.taylor_factors_dtype, - is_skip_compute=is_skip_compute, + stop_predicts=config.stop_predicts, + is_skip=is_skip, ) - registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK) + registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK) \ No newline at end of file