diff --git a/src/diffusers/hooks/taylorseer_cache.py b/src/diffusers/hooks/taylorseer_cache.py
index 807e15558f..f400576fed 100644
--- a/src/diffusers/hooks/taylorseer_cache.py
+++ b/src/diffusers/hooks/taylorseer_cache.py
@@ -4,36 +4,25 @@ from dataclasses import dataclass
 from typing import Optional, List, Dict, Tuple
 
 import torch
+import torch.nn as nn
 
-from .hooks import ModelHook
-from ..models.attention import Attention
-from ..models.attention import AttentionModuleMixin
-from ._common import _ATTENTION_CLASSES
-from ..hooks import HookRegistry
+from .hooks import ModelHook, StateManager, HookRegistry
 from ..utils import logging
 
 logger = logging.get_logger(__name__)
 
-
-_TAYLORSEER_ATTENTION_CACHE_HOOK = "taylorseer_attention_cache"
-
-# Predefined cache templates for optimized architectures
-_CACHE_TEMPLATES: Dict[str, Dict[str, List[str]]] = {
-    "flux": {
-        "cache": [
-            r"transformer_blocks\.\d+\.attn",
-            r"transformer_blocks\.\d+\.ff",
-            r"transformer_blocks\.\d+\.ff_context",
-            r"single_transformer_blocks\.\d+\.proj_out",
-        ],
-        "skip": [
-            r"single_transformer_blocks\.\d+\.attn",
-            r"single_transformer_blocks\.\d+\.proj_mlp",
-            r"single_transformer_blocks\.\d+\.act_mlp",
-        ],
-    },
-}
-
+_TAYLORSEER_CACHE_HOOK = "taylorseer_cache"
+_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = (
+    "^blocks.*attn",
+    "^transformer_blocks.*attn",
+    "^single_transformer_blocks.*attn",
+)
+_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("^temporal_transformer_blocks.*attn",)
+_TRANSFORMER_BLOCK_IDENTIFIERS = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS + _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
+_BLOCK_IDENTIFIERS = ("^[^.]*block[^.]*\\.[^.]+$",)
+_PROJ_OUT_IDENTIFIERS = ("^proj_out$",)
+
 
 @dataclass
 class TaylorSeerCacheConfig:
@@ -43,37 +32,29 @@ class TaylorSeerCacheConfig:
 
     Attributes:
         warmup_steps (`int`, defaults to `3`):
-            Number of *outer* diffusion steps to run with full computation
+            Number of denoising steps to run with full computation
             before enabling caching. During warmup, the Taylor series
             factors are still updated, but no predictions are used.
 
         predict_steps (`int`, defaults to `5`):
-            Number of prediction (cached) steps to take between two full
-            computations. That is, once a module state is refreshed, it will
-            be reused for `predict_steps` subsequent outer steps, then a new
-            full forward will be computed on the next step.
+            Refresh interval for cached modules. After warmup, a full forward
+            is recomputed every `predict_steps` denoising steps; the steps in
+            between reuse the cached Taylor expansion.
 
         stop_predicts (`int`, *optional*, defaults to `None`):
-            Outer diffusion step index at which caching is disabled.
-            If provided, for `true_step >= stop_predicts` all modules are
-            evaluated normally (no predictions, no state updates).
+            Denoising step index at which caching is disabled. If provided,
+            all modules are evaluated normally once the current step reaches
+            `stop_predicts`, and no further predictions are made.
 
         max_order (`int`, defaults to `1`):
             Maximum order of Taylor series expansion to approximate the
             features. Higher order gives closer approximation but more
            compute.
 
-        num_inner_loops (`int`, defaults to `1`):
-            Number of inner loops per outer diffusion step. For example,
-            with classifier-free guidance (CFG) you typically have 2 inner
-            loops: unconditional and conditional branches.
-
-        taylor_factors_dtype (`torch.dtype`, defaults to `torch.float32`):
+        taylor_factors_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
             Data type for computing Taylor series expansion factors.
-
-        architecture (`str`, *optional*, defaults to `None`):
-            If provided, will look up default `cache` and `skip` regex
-            patterns in `_CACHE_TEMPLATES[architecture]`. These can be
-            overridden by `skip_identifiers` and `cache_identifiers`.
+            Lower precision reduces memory usage; higher precision improves
+            numerical stability.
 
         skip_identifiers (`List[str]`, *optional*, defaults to `None`):
             Regex patterns (fullmatch) for module names to be placed in
@@ -85,10 +66,12 @@ class TaylorSeerCacheConfig:
             Regex patterns (fullmatch) for module names to be placed in
             Taylor-series caching mode.
 
+        lite (`bool`, *optional*, defaults to `False`):
+            Whether to use the TaylorSeer Lite variant, which reduces memory
+            usage. This option overrides any user-provided `skip_identifiers`
+            or `cache_identifiers` patterns.
+
     Notes:
         - Patterns are applied with `re.fullmatch` on `module_name`.
-        - If either `skip_identifiers` or `cache_identifiers` is provided
-          (or inferred from `architecture`), only modules matching at least
-          one of those patterns will be hooked.
-        - If neither is provided, all attention-like modules will be hooked.
+        - If either `skip_identifiers` or `cache_identifiers` is provided,
+          only modules matching at least one of those patterns will be hooked.
+        - If neither is provided, modules matching the default transformer
+          attention-block patterns will be cached.
     """
@@ -97,11 +80,10 @@
     predict_steps: int = 5
     stop_predicts: Optional[int] = None
     max_order: int = 1
-    num_inner_loops: int = 1
-    taylor_factors_dtype: torch.dtype = torch.float32
-    architecture: str | None = None
+    taylor_factors_dtype: Optional[torch.dtype] = torch.bfloat16
     skip_identifiers: Optional[List[str]] = None
    cache_identifiers: Optional[List[str]] = None
+    lite: bool = False
 
     def __repr__(self) -> str:
         return (
@@ -110,364 +92,153 @@ class TaylorSeerCacheConfig:
             f"predict_steps={self.predict_steps}, "
             f"stop_predicts={self.stop_predicts}, "
             f"max_order={self.max_order}, "
-            f"num_inner_loops={self.num_inner_loops}, "
             f"taylor_factors_dtype={self.taylor_factors_dtype}, "
-            f"architecture={self.architecture}, "
             f"skip_identifiers={self.skip_identifiers}, "
-            f"cache_identifiers={self.cache_identifiers})"
+            f"cache_identifiers={self.cache_identifiers}, "
+            f"lite={self.lite})"
         )
 
-    @classmethod
-    def get_identifiers_template(cls) -> Dict[str, Dict[str, List[str]]]:
-        return _CACHE_TEMPLATES
-
-
-class TaylorSeerOutputState:
-    """
-    Manages the state for Taylor series-based prediction of a single attention output.
-
-    Tracks Taylor expansion factors, last update step, and remaining prediction steps.
-    The Taylor expansion uses the (outer) timestep as the independent variable.
-
-    This class is designed to handle state for a single inner loop index and a single
-    output (in cases where the module forward returns multiple tensors).
-    """
+class TaylorSeerState:
+    """
+    Tracks the Taylor expansion factors for every output tensor of a hooked
+    module, together with the denoising step at which they were last refreshed.
+    """
+
     def __init__(
         self,
-        module_name: str,
-        taylor_factors_dtype: torch.dtype,
-        module_dtype: torch.dtype,
-        is_skip: bool = False,
+        taylor_factors_dtype: Optional[torch.dtype] = torch.bfloat16,
+        max_order: int = 1,
     ):
-        self.module_name = module_name
         self.taylor_factors_dtype = taylor_factors_dtype
-        self.module_dtype = module_dtype
-        self.is_skip = is_skip
+        self.max_order = max_order
 
-        self.remaining_predictions: int = 0
+        self.module_dtypes: Tuple[torch.dtype, ...] = ()
         self.last_update_step: Optional[int] = None
-        self.taylor_factors: Dict[int, torch.Tensor] = {}
+        self.taylor_factors: Dict[int, Dict[int, torch.Tensor]] = {}
 
-        # For skip-mode modules
-        self.dummy_shape: Optional[Tuple[int, ...]] = None
-        self.device: Optional[torch.device] = None
-        self.dummy_tensor: Optional[torch.Tensor] = None
+        self.current_step = -1
 
     def reset(self) -> None:
-        self.remaining_predictions = 0
         self.last_update_step = None
         self.taylor_factors = {}
-        self.dummy_shape = None
-        self.device = None
-        self.dummy_tensor = None
+        self.current_step = -1
 
     def update(
         self,
-        features: torch.Tensor,
+        outputs: Tuple[torch.Tensor, ...],
         current_step: int,
-        max_order: int,
-        predict_steps: int,
     ) -> None:
-        """
-        Update Taylor factors based on the current features and (outer) timestep.
-
-        For non-skip modules, finite difference approximations for derivatives are
-        computed using recursive divided differences.
-
-        Args:
-            features: Attention output features to update with.
-            current_step: Current outer timestep (true diffusion step).
-            max_order: Maximum Taylor expansion order.
-            predict_steps: Number of prediction steps to allow after this update.
-        """
-        if self.is_skip:
-            # For skip modules we only need shape & device and a dummy tensor.
-            self.dummy_shape = features.shape
-            self.device = features.device
-            # zero is safer than uninitialized values for a "skipped" module
-            self.dummy_tensor = torch.zeros(
-                self.dummy_shape,
-                dtype=self.module_dtype,
-                device=self.device,
-            )
-            self.taylor_factors = {}
-            self.last_update_step = current_step
-            self.remaining_predictions = predict_steps
-            return
-
-        features = features.to(self.taylor_factors_dtype)
-        new_factors: Dict[int, torch.Tensor] = {0: features}
-
-        is_first_update = self.last_update_step is None
-
-        if not is_first_update:
-            delta_step = current_step - self.last_update_step
-            if delta_step == 0:
-                raise ValueError("Delta step cannot be zero for TaylorSeer update.")
-
-            # Recursive divided differences up to max_order
-            for i in range(max_order):
-                prev = self.taylor_factors.get(i)
-                if prev is None:
-                    break
-                new_factors[i + 1] = (new_factors[i] - prev.to(self.taylor_factors_dtype)) / delta_step
-
-        # Keep factors in taylor_factors_dtype
-        self.taylor_factors = new_factors
+        """
+        Refresh the Taylor factors for each output tensor using recursive
+        divided differences over the step delta since the last update.
+        """
+        self.module_dtypes = tuple(output.dtype for output in outputs)
+        for i in range(len(outputs)):
+            features = outputs[i].to(self.taylor_factors_dtype)
+            new_factors: Dict[int, torch.Tensor] = {0: features}
+            is_first_update = self.last_update_step is None
+            if not is_first_update:
+                delta_step = current_step - self.last_update_step
+                if delta_step == 0:
+                    raise ValueError("Delta step cannot be zero for TaylorSeer update.")
+
+                # Recursive divided differences up to max_order
+                for j in range(self.max_order):
+                    prev = self.taylor_factors[i].get(j)
+                    if prev is None:
+                        break
+                    new_factors[j + 1] = (new_factors[j] - prev.to(self.taylor_factors_dtype)) / delta_step
+            self.taylor_factors[i] = new_factors
 
         self.last_update_step = current_step
-        self.remaining_predictions = predict_steps
-
-        if self.module_name == "proj_out":
-            logger.debug(
-                "[UPDATE] module=%s remaining_predictions=%d current_step=%d is_first_update=%s",
-                self.module_name,
-                self.remaining_predictions,
-                current_step,
-                is_first_update,
-            )
 
-    def predict(self, current_step: int) -> torch.Tensor:
+    def predict(self, current_step: int) -> List[torch.Tensor]:
-        """
-        Predict features using the Taylor series at the given (outer) timestep.
- - Args: - current_step: Current outer timestep for prediction. - - Returns: - Predicted features in the module's dtype. - """ - if self.is_skip: - if self.dummy_tensor is None: - raise ValueError("Cannot predict for skip module without prior update.") - self.remaining_predictions -= 1 - return self.dummy_tensor - if self.last_update_step is None: raise ValueError("Cannot predict without prior initialization/update.") step_offset = current_step - self.last_update_step - output: torch.Tensor if not self.taylor_factors: raise ValueError("Taylor factors empty during prediction.") - # Accumulate Taylor series: f(t0 + Δt) ≈ Σ f^{(n)}(t0) * (Δt^n / n!) - output = torch.zeros_like(self.taylor_factors[0]) - for order, factor in self.taylor_factors.items(): - # Note: order starts at 0 - coeff = (step_offset**order) / math.factorial(order) - output = output + factor * coeff + outputs = [] + for i in range(len(self.module_dtypes)): + taylor_factors = self.taylor_factors[i] + # Accumulate Taylor series: f(t0 + Δt) ≈ Σ f^{(n)}(t0) * (Δt^n / n!) + output = torch.zeros_like(taylor_factors[0]) + for order, factor in taylor_factors.items(): + # Note: order starts at 0 + coeff = (step_offset**order) / math.factorial(order) + output = output + factor * coeff + outputs.append(output.to(self.module_dtypes[i])) - self.remaining_predictions -= 1 - out = output.to(self.module_dtype) - - if self.module_name == "proj_out": - logger.debug( - "[PREDICT] module=%s remaining_predictions=%d current_step=%d last_update_step=%s", - self.module_name, - self.remaining_predictions, - current_step, - self.last_update_step, - ) - - return out + return outputs -class TaylorSeerAttentionCacheHook(ModelHook): - """ - Hook for caching and predicting attention outputs using Taylor series approximations. - - Applies to attention modules in diffusion models (e.g., Flux). - Performs full computations during warmup, then alternates between blocks of - predictions and refreshes. - - The hook maintains separate states for each inner loop index (e.g., for - classifier-free guidance). Each inner loop has its own list of - `TaylorSeerOutputState` instances, one per output tensor from the module's - forward (typically one). - - The `step_counter` increments on every forward call of this module. - We define: - - `inner_index = step_counter % num_inner_loops` - - `true_step = step_counter // num_inner_loops` - - Warmup, prediction, and updates are handled per inner loop, but use the - shared `true_step` (outer diffusion step). 
-    """
-
+class TaylorSeerCacheHook(ModelHook):
+    """
+    Caches module outputs and predicts them with a Taylor series expansion
+    between full computations. Modules hooked in "skip" mode return their most
+    recently computed outputs instead of a new prediction.
+    """
+
     _is_stateful = True
 
     def __init__(
         self,
         module_name: str,
         predict_steps: int,
-        max_order: int,
         warmup_steps: int,
         taylor_factors_dtype: torch.dtype,
-        num_inner_loops: int = 1,
+        state_manager: StateManager,
         stop_predicts: Optional[int] = None,
         is_skip: bool = False,
     ):
         super().__init__()
-        if num_inner_loops <= 0:
-            raise ValueError("num_inner_loops must be >= 1")
-
         self.module_name = module_name
         self.predict_steps = predict_steps
-        self.max_order = max_order
         self.warmup_steps = warmup_steps
         self.stop_predicts = stop_predicts
-        self.num_inner_loops = num_inner_loops
         self.taylor_factors_dtype = taylor_factors_dtype
+        self.state_manager = state_manager
         self.is_skip = is_skip
 
-        self.step_counter: int = -1
-        self.states: Optional[List[Optional[List[TaylorSeerOutputState]]]] = None
-        self.num_outputs: Optional[int] = None
+        self.dummy_outputs = None
 
     def initialize_hook(self, module: torch.nn.Module):
-        self.step_counter = -1
-        self.states = None
-        self.num_outputs = None
         return module
 
     def reset_state(self, module: torch.nn.Module) -> None:
         """
         Reset state between sampling runs.
         """
-        self.step_counter = -1
-        self.states = None
-        self.num_outputs = None
-
-    @staticmethod
-    def _listify(outputs):
-        if isinstance(outputs, torch.Tensor):
-            return [outputs]
-        return list(outputs)
-
-    def _delistify(self, outputs_list):
-        if self.num_outputs == 1:
-            return outputs_list[0]
-        return tuple(outputs_list)
-
-    def _ensure_states_initialized(
-        self,
-        module: torch.nn.Module,
-        inner_index: int,
-        true_step: int,
-        *args,
-        **kwargs,
-    ) -> Optional[List[torch.Tensor]]:
-        """
-        Ensure per-inner-loop states exist. If this is the first call for this
-        inner_index, perform a full forward, initialize states, and return the
-        outputs. Otherwise, return None.
-        """
-        if self.states is None:
-            self.states = [None for _ in range(self.num_inner_loops)]
-
-        if self.states[inner_index] is not None:
-            return None
-
-        if self.module_name == "proj_out":
-            logger.debug(
-                "[FIRST STEP] Initializing states for %s (inner_index=%d, true_step=%d)",
-                self.module_name,
-                inner_index,
-                true_step,
-            )
-
-        # First step for this inner loop: always full compute and initialize.
-        attention_outputs = self._listify(self.fn_ref.original_forward(*args, **kwargs))
-        module_dtype = attention_outputs[0].dtype
-
-        if self.num_outputs is None:
-            self.num_outputs = len(attention_outputs)
-        elif self.num_outputs != len(attention_outputs):
-            raise ValueError("Output count mismatch across inner loops.")
-
-        self.states[inner_index] = [
-            TaylorSeerOutputState(
-                self.module_name,
-                self.taylor_factors_dtype,
-                module_dtype,
-                is_skip=self.is_skip,
-            )
-            for _ in range(self.num_outputs)
-        ]
-
-        for i, features in enumerate(attention_outputs):
-            self.states[inner_index][i].update(
-                features=features,
-                current_step=true_step,
-                max_order=self.max_order,
-                predict_steps=self.predict_steps,
-            )
-
-        return attention_outputs
+        self.dummy_outputs = None
+        self.state_manager.reset()
 
     def new_forward(self, module: torch.nn.Module, *args, **kwargs):
-        self.step_counter += 1
-        inner_index = self.step_counter % self.num_inner_loops
-        true_step = self.step_counter // self.num_inner_loops
-        is_warmup_phase = true_step < self.warmup_steps
+        state: TaylorSeerState = self.state_manager.get_state()
+        state.current_step += 1
+        current_step = state.current_step
+        is_warmup_phase = current_step < self.warmup_steps
+        # Recompute fully during warmup, on periodic refresh steps, and once
+        # `stop_predicts` is reached; otherwise serve the Taylor prediction.
+        should_compute = (
+            is_warmup_phase
+            or ((current_step - self.warmup_steps - 1) % self.predict_steps == 0)
+            or (self.stop_predicts is not None and current_step >= self.stop_predicts)
+        )
+        if should_compute:
+            outputs = self.fn_ref.original_forward(*args, **kwargs)
+            if not self.is_skip:
+                state.update((outputs,) if isinstance(outputs, torch.Tensor) else outputs, current_step)
+            else:
+                self.dummy_outputs = outputs
+            return outputs
 
-        if self.module_name == "proj_out":
-            logger.debug(
-                "[FORWARD] module=%s step_counter=%d inner_index=%d true_step=%d is_warmup=%s",
-                self.module_name,
-                self.step_counter,
-                inner_index,
-                true_step,
-                is_warmup_phase,
-            )
+        if self.is_skip:
+            return self.dummy_outputs
 
-        # First-time initialization for this inner loop
-        maybe_outputs = self._ensure_states_initialized(module, inner_index, true_step, *args, **kwargs)
-        if maybe_outputs is not None:
-            return self._delistify(maybe_outputs)
-
-        assert self.states is not None
-        states = self.states[inner_index]
-        assert states is not None and len(states) > 0
-
-        # If stop_predicts is set and we are past that step, always run full forward
-        if self.stop_predicts is not None and true_step >= self.stop_predicts:
-            attention_outputs = self._listify(self.fn_ref.original_forward(*args, **kwargs))
-            return self._delistify(attention_outputs)
-
-        # Decide between prediction vs refresh
-        # - Never predict during warmup.
-        # - Otherwise, predict while we still have remaining_predictions.
-        should_predict = (not is_warmup_phase) and (states[0].remaining_predictions > 0)
-
-        if should_predict:
-            predicted_outputs = [state.predict(true_step) for state in states]
-            return self._delistify(predicted_outputs)
-
-        # Full compute: warmup or refresh
-        attention_outputs = self._listify(self.fn_ref.original_forward(*args, **kwargs))
-        for i, features in enumerate(attention_outputs):
-            states[i].update(
-                features=features,
-                current_step=true_step,
-                max_order=self.max_order,
-                predict_steps=self.predict_steps,
-            )
-        return self._delistify(attention_outputs)
+        outputs = state.predict(current_step)
+        return outputs[0] if len(outputs) == 1 else tuple(outputs)
 
 
 def _resolve_patterns(config: TaylorSeerCacheConfig) -> Tuple[List[str], List[str]]:
     """
-    Resolve effective skip and cache pattern lists from config + templates.
+    Resolve the effective skip and cache pattern lists from the config.
     """
-    template = _CACHE_TEMPLATES.get(config.architecture or "", {})
-    default_skip = template.get("skip", [])
-    default_cache = template.get("cache", [])
-    skip_patterns = config.skip_identifiers if config.skip_identifiers is not None else default_skip
-    cache_patterns = config.cache_identifiers if config.cache_identifiers is not None else default_cache
+    skip_patterns = config.skip_identifiers
+    cache_patterns = config.cache_identifiers
 
     return skip_patterns or [], cache_patterns or []
 
@@ -496,8 +267,6 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfi
     ...     max_order=1,
     ...     warmup_steps=3,
     ...     taylor_factors_dtype=torch.float32,
-    ...     architecture="flux",
-    ...     num_inner_loops=2,  # e.g. CFG
     ...     )
     >>> pipe.transformer.enable_cache(config)
     ```
@@ -507,67 +276,68 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfi
     logger.debug("TaylorSeer skip identifiers: %s", skip_patterns)
     logger.debug("TaylorSeer cache identifiers: %s", cache_patterns)
 
-    use_patterns = bool(skip_patterns or cache_patterns)
+    cache_patterns = cache_patterns or _TRANSFORMER_BLOCK_IDENTIFIERS
+
+    if config.lite:
+        logger.info("Using TaylorSeer Lite variant for cache.")
+        if config.skip_identifiers or config.cache_identifiers:
+            logger.warning("`lite=True` overrides user-provided `skip_identifiers`/`cache_identifiers` patterns.")
+        cache_patterns = _PROJ_OUT_IDENTIFIERS
+        skip_patterns = _BLOCK_IDENTIFIERS
 
     for name, submodule in module.named_modules():
         matches_skip = any(re.fullmatch(pattern, name) for pattern in skip_patterns)
         matches_cache = any(re.fullmatch(pattern, name) for pattern in cache_patterns)
-
-        if use_patterns:
-            # If patterns are configured (either skip or cache), only touch modules
-            # that explicitly match at least one pattern.
-            if not (matches_skip or matches_cache):
-                continue
-
-            logger.debug(
-                "Applying TaylorSeer cache to %s (mode=%s)",
-                name,
-                "skip" if matches_skip else "cache",
-            )
-            _apply_taylorseer_cache_hook(
-                name=name,
-                module=submodule,
-                config=config,
-                is_skip=matches_skip,
-            )
-        else:
-            # No patterns configured: fall back to "all attention modules".
-            if isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
-                logger.debug("Applying TaylorSeer cache to %s (fallback attention mode)", name)
-                _apply_taylorseer_cache_hook(
-                    name=name,
-                    module=submodule,
-                    config=config,
-                    is_skip=False,
-                )
+        if not (matches_skip or matches_cache):
+            continue
+        logger.debug(
+            "Applying TaylorSeer cache to %s (mode=%s)",
+            name,
+            "skip" if matches_skip else "cache",
+        )
+        state_manager = StateManager(
+            TaylorSeerState,
+            init_kwargs={
+                "taylor_factors_dtype": config.taylor_factors_dtype,
+                "max_order": config.max_order,
+            },
+        )
+        _apply_taylorseer_cache_hook(
+            name=name,
+            module=submodule,
+            config=config,
+            is_skip=matches_skip,
+            state_manager=state_manager,
+        )
 
 
 def _apply_taylorseer_cache_hook(
     name: str,
-    module: Attention,
+    module: nn.Module,
    config: TaylorSeerCacheConfig,
    is_skip: bool,
+    state_manager: StateManager,
 ):
    """
-    Registers the TaylorSeer hook on the specified attention module.
+    Registers the TaylorSeer hook on the specified nn.Module.
 
    Args:
        name: Name of the module.
-        module: The attention-like module to be hooked.
+        module: The nn.Module to be hooked.
        config: Cache configuration.
        is_skip: Whether this module should operate in "skip" mode.
+        state_manager: Provides the per-module `TaylorSeerState`.
    """
    registry = HookRegistry.check_if_exists_or_initialize(module)
 
    hook = TaylorSeerCacheHook(
        module_name=name,
        predict_steps=config.predict_steps,
-        max_order=config.max_order,
        warmup_steps=config.warmup_steps,
        taylor_factors_dtype=config.taylor_factors_dtype,
-        num_inner_loops=config.num_inner_loops,
        stop_predicts=config.stop_predicts,
        is_skip=is_skip,
+        state_manager=state_manager,
    )
-    registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK)
+    registry.register_hook(hook, _TAYLORSEER_CACHE_HOOK)
diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py
index ffbf296ff6..56de1e6463 100644
--- a/src/diffusers/models/cache_utils.py
+++ b/src/diffusers/models/cache_utils.py
@@ -97,7 +97,7 @@ class CacheMixin:
         from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
         from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
         from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
-        from ..hooks.taylorseer_cache import _TAYLORSEER_ATTENTION_CACHE_HOOK
+        from ..hooks.taylorseer_cache import _TAYLORSEER_CACHE_HOOK
 
         if self._cache_config is None:
             logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
@@ -113,7 +113,7 @@ class CacheMixin:
         elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
             registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
         elif isinstance(self._cache_config, TaylorSeerCacheConfig):
-            registry.remove_hook(_TAYLORSEER_ATTENTION_CACHE_HOOK, recurse=True)
+            registry.remove_hook(_TAYLORSEER_CACHE_HOOK, recurse=True)
         else:
             raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")
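A minimal usage sketch of the API this diff produces, assuming a Flux checkpoint and illustrative parameter values (`enable_cache` comes from `CacheMixin`, as in the docstring example above; the checkpoint id and prompt are placeholders):

```python
import torch
from diffusers import FluxPipeline
from diffusers.hooks.taylorseer_cache import TaylorSeerCacheConfig

# Hypothetical checkpoint and prompt; any transformer that exposes
# CacheMixin.enable_cache works the same way.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")

config = TaylorSeerCacheConfig(
    warmup_steps=3,        # full computation for the first denoising steps
    predict_steps=5,       # refresh the Taylor factors every 5 steps afterwards
    max_order=1,           # first-order (linear) extrapolation
    taylor_factors_dtype=torch.bfloat16,
    # lite=True,           # lower-memory variant; overrides identifier patterns
)
pipe.transformer.enable_cache(config)

image = pipe("a photo of a cat", num_inference_steps=28).images[0]
```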
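The lite-mode identifier constants can be sanity-checked in isolation. A standalone sketch (patterns copied from this diff, module names hypothetical) showing that whole transformer blocks fall into skip mode while only the top-level `proj_out` is cached and predicted:

```python
import re

# Patterns copied from the diff; the names below are hypothetical examples.
_BLOCK_IDENTIFIERS = ("^[^.]*block[^.]*\\.[^.]+$",)
_PROJ_OUT_IDENTIFIERS = ("^proj_out$",)

names = [
    "transformer_blocks.0",         # skip: a whole block
    "single_transformer_blocks.7",  # skip: a whole block
    "transformer_blocks.0.attn",    # untouched: nested below a block
    "proj_out",                     # cache: the final output projection
]
for name in names:
    if any(re.fullmatch(p, name) for p in _BLOCK_IDENTIFIERS):
        mode = "skip"
    elif any(re.fullmatch(p, name) for p in _PROJ_OUT_IDENTIFIERS):
        mode = "cache"
    else:
        mode = "untouched"
    print(f"{name:30s} -> {mode}")
```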
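For `max_order=1`, the divided-difference bookkeeping in `TaylorSeerState.update` and the factorial-weighted accumulation in `predict` reduce to linear extrapolation. A self-contained numeric check of that identity (a simplified re-implementation for illustration, not the class itself):

```python
import math

import torch

factors = {}  # order -> factor, mirroring one entry of TaylorSeerState.taylor_factors

def update(features, step, last_step, max_order=1):
    new = {0: features}
    if last_step is not None:
        delta = step - last_step
        # Recursive divided differences, as in TaylorSeerState.update
        for j in range(max_order):
            if j not in factors:
                break
            new[j + 1] = (new[j] - factors[j]) / delta
    factors.clear()
    factors.update(new)

def predict(step, last_step):
    # f(t0 + dt) ~= sum_n f^(n)(t0) * dt^n / n!
    offset = step - last_step
    out = torch.zeros_like(factors[0])
    for order, factor in factors.items():
        out = out + factor * (offset**order) / math.factorial(order)
    return out

update(torch.tensor([1.0]), step=0, last_step=None)
update(torch.tensor([3.0]), step=2, last_step=0)  # slope = (3 - 1) / 2 = 1
print(predict(step=4, last_step=2))               # tensor([5.]) = 3 + 1 * 2
```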