From 1099e493e635526c8ecbc8ebca0f57e4bea2a0d8 Mon Sep 17 00:00:00 2001 From: toilaluan Date: Fri, 14 Nov 2025 07:00:12 +0000 Subject: [PATCH] refractor, add docs --- src/diffusers/__init__.py | 4 + src/diffusers/hooks/__init__.py | 1 + src/diffusers/hooks/taylorseer_cache.py | 240 +++++++++++++++++------- src/diffusers/models/cache_utils.py | 9 +- 4 files changed, 182 insertions(+), 72 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 02df34c07e..69d4aa4ba3 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -169,10 +169,12 @@ else: "LayerSkipConfig", "PyramidAttentionBroadcastConfig", "SmoothedEnergyGuidanceConfig", + "TaylorSeerCacheConfig", "apply_faster_cache", "apply_first_block_cache", "apply_layer_skip", "apply_pyramid_attention_broadcast", + "apply_taylorseer_cache", ] ) _import_structure["models"].extend( @@ -883,10 +885,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: LayerSkipConfig, PyramidAttentionBroadcastConfig, SmoothedEnergyGuidanceConfig, + TaylorSeerCacheConfig, apply_faster_cache, apply_first_block_cache, apply_layer_skip, apply_pyramid_attention_broadcast, + apply_taylorseer_cache, ) from .models import ( AllegroTransformer3DModel, diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index 524a92ea99..1d9d43d96b 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -25,3 +25,4 @@ if is_torch_available(): from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig + from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache \ No newline at end of file diff --git a/src/diffusers/hooks/taylorseer_cache.py b/src/diffusers/hooks/taylorseer_cache.py index 6c99f095e2..509f6ba117 100644 --- a/src/diffusers/hooks/taylorseer_cache.py +++ b/src/diffusers/hooks/taylorseer_cache.py @@ -1,9 +1,6 @@ -# Experimental hook for TaylorSeer cache -# Supports Flux only for now - import torch from dataclasses import dataclass -from typing import Callable +from typing import Callable, Optional, List, Dict from .hooks import ModelHook import math from ..models.attention import Attention @@ -13,118 +10,219 @@ from ._common import ( ) from ..hooks import HookRegistry from ..utils import logging + logger = logging.get_logger(__name__) # pylint: disable=invalid-name _TAYLORSEER_ATTENTION_CACHE_HOOK = "taylorseer_attention_cache" @dataclass class TaylorSeerCacheConfig: - warmup_steps: int = 3 # full compute some first steps - fresh_threshold: int = 5 # interleave cache and compute: `fresh_threshold` steps are cached, then 1 full compute step is performed - max_order: int = 1 # order of Taylor series expansion - current_timestep_callback: Callable[[], int] = None + """ + Configuration for TaylorSeer cache. + See: https://huggingface.co/papers/2503.06923 -class TaylorSeerState: - def __init__(self): - self.predict_counter: int = 0 - self.last_step: int = 1000 - self.taylor_factors: dict[int, torch.Tensor] = {} + Attributes: + warmup_steps (int, defaults to 3): Number of warmup steps without caching. + predict_steps (int, defaults to 5): Number of prediction (cache) steps between non-cached steps. + max_order (int, defaults to 1): Maximum order of Taylor series expansion to approximate the features. + taylor_factors_dtype (torch.dtype, defaults to torch.float32): Data type for Taylor series expansion factors. + """ + warmup_steps: int = 3 + predict_steps: int = 5 + max_order: int = 1 + taylor_factors_dtype: torch.dtype = torch.float32 + + def __repr__(self) -> str: + return f"TaylorSeerCacheConfig(warmup_steps={self.warmup_steps}, predict_steps={self.predict_steps}, max_order={self.max_order}, taylor_factors_dtype={self.taylor_factors_dtype})" + +class TaylorSeerOutputState: + """ + Manages the state for Taylor series-based prediction of a single attention output. + Tracks Taylor expansion factors, last update step, and remaining prediction steps. + The Taylor expansion uses the timestep as the independent variable for approximation. + """ + + def __init__(self, module_name: str, taylor_factors_dtype: torch.dtype, module_dtype: torch.dtype): + self.module_name = module_name + self.remaining_predictions: int = 0 + self.last_update_step: Optional[int] = None + self.taylor_factors: Dict[int, torch.Tensor] = {} + self.taylor_factors_dtype = taylor_factors_dtype + self.module_dtype = module_dtype def reset(self): - self.predict_counter = 0 - self.last_step = 1000 + self.remaining_predictions = 0 + self.last_update_step = None self.taylor_factors = {} - def update(self, features: torch.Tensor, current_step: int, max_order: int, refresh_threshold: int): - logger.debug("="*10) - N = self.last_step - current_step - logger.debug(f"update: N: {N}, current_step: {current_step}, last_step: {self.last_step}") - # initialize the first order taylor factors - new_taylor_factors = {0: features} - for i in range(max_order): - if (self.taylor_factors.get(i) is not None) and current_step > 1: - new_taylor_factors[i+1] = (self.taylor_factors[i] - new_taylor_factors[i]) / N - else: - break - self.taylor_factors = new_taylor_factors - self.last_step = current_step - self.predict_counter = refresh_threshold - logger.debug(f"last_step: {self.last_step}") - logger.debug(f"predict_counter: {self.predict_counter}") - logger.debug("="*10) + def update(self, features: torch.Tensor, current_step: int, max_order: int, predict_steps: int, is_first_update: bool): + """ + Updates the Taylor factors based on the current features and timestep. + Computes finite difference approximations for derivatives using recursive divided differences. - def predict(self, current_step: int): - k = current_step - self.last_step + Args: + features (torch.Tensor): The attention output features to update with. + current_step (int): The current timestep or step number from the diffusion model. + max_order (int): Maximum order of the Taylor expansion. + predict_steps (int): Number of prediction steps to set after update. + is_first_update (bool): Whether this is the initial update (skips difference computation). + """ + features = features.to(self.taylor_factors_dtype) + new_factors = {0: features} + if not is_first_update: + if self.last_update_step is None: + raise ValueError("Cannot update without prior initialization.") + delta_step = current_step - self.last_update_step + if delta_step == 0: + raise ValueError("Delta step cannot be zero for updates.") + for i in range(max_order): + if i in self.taylor_factors: + # Finite difference: (current - previous) / delta for forward approximation + new_factors[i + 1] = (new_factors[i] - self.taylor_factors[i].to(self.taylor_factors_dtype)) / delta_step + + # taylor factors will be kept in the taylor_factors_dtype + self.taylor_factors = new_factors + self.last_update_step = current_step + self.remaining_predictions = predict_steps + + def predict(self, current_step: int) -> torch.Tensor: + """ + Predicts the features using the Taylor series expansion at the given timestep. + + Args: + current_step (int): The current timestep for prediction. + + Returns: + torch.Tensor: The predicted features in the module's dtype. + """ + if self.last_update_step is None: + raise ValueError("Cannot predict without prior update.") + step_offset = current_step - self.last_update_step device = self.taylor_factors[0].device - output = torch.zeros_like(self.taylor_factors[0], device=device) - for i in range(len(self.taylor_factors)): - output += self.taylor_factors[i] * (k ** i) / math.factorial(i) - self.predict_counter -= 1 - logger.debug(f"predict_counter: {self.predict_counter}") - logger.debug(f"k: {k}") - return output + output = torch.zeros_like(self.taylor_factors[0], device=device, dtype=self.taylor_factors_dtype) + for order in range(len(self.taylor_factors)): + output += self.taylor_factors[order] * (step_offset ** order) / math.factorial(order) + self.remaining_predictions -= 1 + # output will be converted to the module's dtype + return output.to(self.module_dtype) class TaylorSeerAttentionCacheHook(ModelHook): + """ + Hook for caching and predicting attention outputs using Taylor series approximations. + Applies to attention modules in diffusion models (e.g., Flux). + Performs full computations during warmup, then alternates between predictions and refreshes. + """ _is_stateful = True - def __init__(self, fresh_threshold: int, max_order: int, current_timestep_callback: Callable[[], int], warmup_steps: int): + def __init__( + self, + module_name: str, + predict_steps: int, + max_order: int, + warmup_steps: int, + taylor_factors_dtype: torch.dtype, + module_dtype: torch.dtype = None, + ): super().__init__() - self.fresh_threshold = fresh_threshold + self.module_name = module_name + self.predict_steps = predict_steps self.max_order = max_order - self.current_timestep_callback = current_timestep_callback self.warmup_steps = warmup_steps + self.step_counter = -1 + self.states: Optional[List[TaylorSeerOutputState]] = None + self.num_outputs: Optional[int] = None + self.taylor_factors_dtype = taylor_factors_dtype + self.module_dtype = module_dtype - def initialize_hook(self, module): + def initialize_hook(self, module: torch.nn.Module): + self.step_counter = -1 self.states = None self.num_outputs = None - self.warmup_steps_counter = 0 + self.module_dtype = None return module def new_forward(self, module: torch.nn.Module, *args, **kwargs): - current_step = self.current_timestep_callback() - assert current_step is not None, "timestep is required for TaylorSeerAttentionCacheHook" + self.step_counter += 1 + is_warmup_phase = self.step_counter < self.warmup_steps if self.states is None: + # First step: always full compute and initialize attention_outputs = self.fn_ref.original_forward(*args, **kwargs) if isinstance(attention_outputs, torch.Tensor): attention_outputs = [attention_outputs] + else: + attention_outputs = list(attention_outputs) + module_dtype = attention_outputs[0].dtype self.num_outputs = len(attention_outputs) - self.states = [TaylorSeerState() for _ in range(self.num_outputs)] - for i, feat in enumerate(attention_outputs): - self.states[i].update(feat, current_step, self.max_order, self.fresh_threshold) - return attention_outputs[0] if len(attention_outputs) == 1 else attention_outputs + self.states = [ + TaylorSeerOutputState(self.module_name, self.taylor_factors_dtype, module_dtype) + for _ in range(self.num_outputs) + ] + for i, features in enumerate(attention_outputs): + self.states[i].update(features, self.step_counter, self.max_order, self.predict_steps, is_first_update=True) + return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs) - should_predict = self.states[0].predict_counter > 0 and self.warmup_steps_counter > self.warmup_steps - - if not should_predict: + should_predict = self.states[0].remaining_predictions > 0 + if is_warmup_phase or not should_predict: + # Full compute during warmup or when refresh needed attention_outputs = self.fn_ref.original_forward(*args, **kwargs) if isinstance(attention_outputs, torch.Tensor): attention_outputs = [attention_outputs] - for i, feat in enumerate(attention_outputs): - self.states[i].update(feat, current_step, self.max_order, self.fresh_threshold) - return attention_outputs[0] if len(attention_outputs) == 1 else attention_outputs + else: + attention_outputs = list(attention_outputs) + is_first_update = self.step_counter == 0 # Only True for the very first step + for i, features in enumerate(attention_outputs): + self.states[i].update(features, self.step_counter, self.max_order, self.predict_steps, is_first_update) + return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs) else: - predicted_outputs = [state.predict(current_step) for state in self.states] - return predicted_outputs[0] if len(predicted_outputs) == 1 else predicted_outputs + # Predict using Taylor series + predicted_outputs = [state.predict(self.step_counter) for state in self.states] + return predicted_outputs[0] if self.num_outputs == 1 else tuple(predicted_outputs) def reset_state(self, module: torch.nn.Module) -> None: if self.states is not None: for state in self.states: state.reset() - return module def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfig): + """ + Applies the TaylorSeer cache to given pipeline. + + Args: + module (torch.nn.Module): The model to apply the hook to. + config (TaylorSeerCacheConfig): Configuration for the cache. + + Example: + ```python + >>> import torch + >>> from diffusers import FluxPipeline, TaylorSeerCacheConfig, apply_taylorseer_cache + + >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> config = TaylorSeerCacheConfig(predict_steps=5, max_order=1, warmup_steps=3, taylor_factors_dtype=torch.float32) + >>> apply_taylorseer_cache(pipe.transformer, config) + ``` + """ for name, submodule in module.named_modules(): - if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)): - continue - logger.debug(f"Applying TaylorSeer cache to {name}") - _apply_taylorseer_cache_on_attention_class(name, submodule, config) + if isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)): + logger.debug(f"Applying TaylorSeer cache to {name}") + _apply_taylorseer_cache_hook(name, submodule, config) +def _apply_taylorseer_cache_hook(name: str, module: Attention, config: TaylorSeerCacheConfig): + """ + Registers the TaylorSeer hook on the specified attention module. -def _apply_taylorseer_cache_on_attention_class(name: str, module: Attention, config: TaylorSeerCacheConfig): - _apply_taylorseer_cache_hook(module, config) - - -def _apply_taylorseer_cache_hook(module: Attention, config: TaylorSeerCacheConfig): + Args: + name (str): Name of the module. + module (Attention): The attention module. + config (TaylorSeerCacheConfig): Configuration for the cache. + """ registry = HookRegistry.check_if_exists_or_initialize(module) - hook = TaylorSeerAttentionCacheHook(config.fresh_threshold, config.max_order, config.current_timestep_callback, config.warmup_steps) + hook = TaylorSeerAttentionCacheHook( + name, + config.predict_steps, + config.max_order, + config.warmup_steps, + config.taylor_factors_dtype, + ) registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK) \ No newline at end of file diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py index 605c0d588c..ffbf296ff6 100644 --- a/src/diffusers/models/cache_utils.py +++ b/src/diffusers/models/cache_utils.py @@ -67,9 +67,11 @@ class CacheMixin: FasterCacheConfig, FirstBlockCacheConfig, PyramidAttentionBroadcastConfig, + TaylorSeerCacheConfig, apply_faster_cache, apply_first_block_cache, apply_pyramid_attention_broadcast, + apply_taylorseer_cache, ) if self.is_cache_enabled: @@ -83,16 +85,19 @@ class CacheMixin: apply_first_block_cache(self, config) elif isinstance(config, PyramidAttentionBroadcastConfig): apply_pyramid_attention_broadcast(self, config) + elif isinstance(config, TaylorSeerCacheConfig): + apply_taylorseer_cache(self, config) else: raise ValueError(f"Cache config {type(config)} is not supported.") self._cache_config = config def disable_cache(self) -> None: - from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig + from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig, TaylorSeerCacheConfig from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK + from ..hooks.taylorseer_cache import _TAYLORSEER_ATTENTION_CACHE_HOOK if self._cache_config is None: logger.warning("Caching techniques have not been enabled, so there's nothing to disable.") @@ -107,6 +112,8 @@ class CacheMixin: registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True) elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig): registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True) + elif isinstance(self._cache_config, TaylorSeerCacheConfig): + registry.remove_hook(_TAYLORSEER_ATTENTION_CACHE_HOOK, recurse=True) else: raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")