From 7b4ad2de63c314489b8129a496bea5c67e31cf7e Mon Sep 17 00:00:00 2001
From: toilaluan
Date: Fri, 14 Nov 2025 09:09:46 +0000
Subject: [PATCH] add configurable cache, skip compute module

---
 src/diffusers/hooks/taylorseer_cache.py | 167 ++++++++++++++++++------
 1 file changed, 125 insertions(+), 42 deletions(-)

diff --git a/src/diffusers/hooks/taylorseer_cache.py b/src/diffusers/hooks/taylorseer_cache.py
index 509f6ba117..89d6da3074 100644
--- a/src/diffusers/hooks/taylorseer_cache.py
+++ b/src/diffusers/hooks/taylorseer_cache.py
@@ -10,10 +10,28 @@ from ._common import (
 )
 from ..hooks import HookRegistry
 from ..utils import logging
-
+import re
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 _TAYLORSEER_ATTENTION_CACHE_HOOK = "taylorseer_attention_cache"
+# Built-in per-architecture presets. Modules matching a "special cache" pattern have
+# their outputs Taylor-approximated on cached steps; modules matching a "skip compute"
+# pattern are not computed at all on cached steps.
+SPECIAL_CACHE_IDENTIFIERS = {
+    "flux": [
+        r"transformer_blocks\.\d+\.attn",
+        r"transformer_blocks\.\d+\.ff",
+        r"transformer_blocks\.\d+\.ff_context",
+        r"single_transformer_blocks\.\d+\.proj_out",
+    ]
+}
+SKIP_COMPUTE_IDENTIFIERS = {
+    "flux": [
+        r"single_transformer_blocks\.\d+\.attn",
+        r"single_transformer_blocks\.\d+\.proj_mlp",
+        r"single_transformer_blocks\.\d+\.act_mlp",
+    ]
+}
+
+
 @dataclass
 class TaylorSeerCacheConfig:
     """
@@ -25,14 +43,22 @@ class TaylorSeerCacheConfig:
         predict_steps (int, defaults to 5): Number of prediction (cache) steps between non-cached steps.
         max_order (int, defaults to 1): Maximum order of Taylor series expansion to approximate the features.
         taylor_factors_dtype (torch.dtype, defaults to torch.float32): Data type for Taylor series expansion factors.
+        architecture (str, defaults to None): Architecture name (e.g., "flux"). When set, the built-in
+            identifier presets for that architecture are used unless explicit identifiers are provided.
+        skip_compute_identifiers (List[str], defaults to None): Regex patterns for modules whose computation
+            is skipped entirely on cached steps.
+        special_cache_identifiers (List[str], defaults to None): Regex patterns for modules whose outputs are
+            cached and predicted on cached steps.
     """
+
     warmup_steps: int = 3
     predict_steps: int = 5
     max_order: int = 1
     taylor_factors_dtype: torch.dtype = torch.float32
+    architecture: Optional[str] = None
+    skip_compute_identifiers: Optional[List[str]] = None
+    special_cache_identifiers: Optional[List[str]] = None
 
     def __repr__(self) -> str:
-        return f"TaylorSeerCacheConfig(warmup_steps={self.warmup_steps}, predict_steps={self.predict_steps}, max_order={self.max_order}, taylor_factors_dtype={self.taylor_factors_dtype})"
+        return (
+            f"TaylorSeerCacheConfig(warmup_steps={self.warmup_steps}, "
+            f"predict_steps={self.predict_steps}, "
+            f"max_order={self.max_order}, "
+            f"taylor_factors_dtype={self.taylor_factors_dtype}, "
+            f"architecture={self.architecture}, "
+            f"skip_compute_identifiers={self.skip_compute_identifiers}, "
+            f"special_cache_identifiers={self.special_cache_identifiers})"
+        )
 
 
 class TaylorSeerOutputState:
     """
@@ -41,20 +67,31 @@ class TaylorSeerOutputState:
     The Taylor expansion uses the timestep as the independent variable for approximation.
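+
+    Concretely, predict() evaluates the truncated expansion around the last full
+    computation at step t0, using the cached finite-difference factors F_k:
+
+        f(t) = sum_{k=0}^{max_order} F_k * (t - t0)^k / k!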
""" - def __init__(self, module_name: str, taylor_factors_dtype: torch.dtype, module_dtype: torch.dtype): + def __init__( + self, module_name: str, taylor_factors_dtype: torch.dtype, module_dtype: torch.dtype, is_skip: bool = False + ): self.module_name = module_name self.remaining_predictions: int = 0 self.last_update_step: Optional[int] = None self.taylor_factors: Dict[int, torch.Tensor] = {} self.taylor_factors_dtype = taylor_factors_dtype self.module_dtype = module_dtype + self.is_skip = is_skip + self.dummy_shape: Optional[Tuple[int, ...]] = None + self.device: Optional[torch.device] = None + self.dummy_tensor: Optional[torch.Tensor] = None def reset(self): self.remaining_predictions = 0 self.last_update_step = None self.taylor_factors = {} + self.dummy_shape = None + self.device = None + self.dummy_tensor = None - def update(self, features: torch.Tensor, current_step: int, max_order: int, predict_steps: int, is_first_update: bool): + def update( + self, features: torch.Tensor, current_step: int, max_order: int, predict_steps: int, is_first_update: bool + ): """ Updates the Taylor factors based on the current features and timestep. Computes finite difference approximations for derivatives using recursive divided differences. @@ -66,23 +103,33 @@ class TaylorSeerOutputState: predict_steps (int): Number of prediction steps to set after update. is_first_update (bool): Whether this is the initial update (skips difference computation). """ - features = features.to(self.taylor_factors_dtype) - new_factors = {0: features} - if not is_first_update: - if self.last_update_step is None: - raise ValueError("Cannot update without prior initialization.") - delta_step = current_step - self.last_update_step - if delta_step == 0: - raise ValueError("Delta step cannot be zero for updates.") - for i in range(max_order): - if i in self.taylor_factors: - # Finite difference: (current - previous) / delta for forward approximation - new_factors[i + 1] = (new_factors[i] - self.taylor_factors[i].to(self.taylor_factors_dtype)) / delta_step + if self.is_skip: + self.dummy_shape = features.shape + self.device = features.device + self.taylor_factors = {} + self.last_update_step = current_step + self.remaining_predictions = predict_steps + else: + features = features.to(self.taylor_factors_dtype) + new_factors = {0: features} + if not is_first_update: + if self.last_update_step is None: + raise ValueError("Cannot update without prior initialization.") + delta_step = current_step - self.last_update_step + if delta_step == 0: + raise ValueError("Delta step cannot be zero for updates.") + for i in range(max_order): + if i in self.taylor_factors: + new_factors[i + 1] = ( + new_factors[i] - self.taylor_factors[i].to(self.taylor_factors_dtype) + ) / delta_step + else: + break - # taylor factors will be kept in the taylor_factors_dtype - self.taylor_factors = new_factors - self.last_update_step = current_step - self.remaining_predictions = predict_steps + # taylor factors will be kept in the taylor_factors_dtype + self.taylor_factors = new_factors + self.last_update_step = current_step + self.remaining_predictions = predict_steps def predict(self, current_step: int) -> torch.Tensor: """ @@ -94,16 +141,22 @@ class TaylorSeerOutputState: Returns: torch.Tensor: The predicted features in the module's dtype. 
""" - if self.last_update_step is None: - raise ValueError("Cannot predict without prior update.") - step_offset = current_step - self.last_update_step - device = self.taylor_factors[0].device - output = torch.zeros_like(self.taylor_factors[0], device=device, dtype=self.taylor_factors_dtype) - for order in range(len(self.taylor_factors)): - output += self.taylor_factors[order] * (step_offset ** order) / math.factorial(order) - self.remaining_predictions -= 1 - # output will be converted to the module's dtype - return output.to(self.module_dtype) + if self.is_skip: + if self.dummy_shape is None or self.device is None: + raise ValueError("Cannot predict for skip module without prior update.") + self.remaining_predictions -= 1 + return torch.empty(self.dummy_shape, dtype=self.module_dtype, device=self.device) + else: + if self.last_update_step is None: + raise ValueError("Cannot predict without prior update.") + step_offset = current_step - self.last_update_step + output = 0 + for order in range(len(self.taylor_factors)): + output += self.taylor_factors[order] * (step_offset**order) * (1 / math.factorial(order)) + self.remaining_predictions -= 1 + # output will be converted to the module's dtype + return output.to(self.module_dtype) + class TaylorSeerAttentionCacheHook(ModelHook): """ @@ -111,6 +164,7 @@ class TaylorSeerAttentionCacheHook(ModelHook): Applies to attention modules in diffusion models (e.g., Flux). Performs full computations during warmup, then alternates between predictions and refreshes. """ + _is_stateful = True def __init__( @@ -120,7 +174,7 @@ class TaylorSeerAttentionCacheHook(ModelHook): max_order: int, warmup_steps: int, taylor_factors_dtype: torch.dtype, - module_dtype: torch.dtype = None, + is_skip_compute: bool = False, ): super().__init__() self.module_name = module_name @@ -131,13 +185,12 @@ class TaylorSeerAttentionCacheHook(ModelHook): self.states: Optional[List[TaylorSeerOutputState]] = None self.num_outputs: Optional[int] = None self.taylor_factors_dtype = taylor_factors_dtype - self.module_dtype = module_dtype + self.is_skip_compute = is_skip_compute def initialize_hook(self, module: torch.nn.Module): self.step_counter = -1 self.states = None self.num_outputs = None - self.module_dtype = None return module def new_forward(self, module: torch.nn.Module, *args, **kwargs): @@ -154,11 +207,15 @@ class TaylorSeerAttentionCacheHook(ModelHook): module_dtype = attention_outputs[0].dtype self.num_outputs = len(attention_outputs) self.states = [ - TaylorSeerOutputState(self.module_name, self.taylor_factors_dtype, module_dtype) + TaylorSeerOutputState( + self.module_name, self.taylor_factors_dtype, module_dtype, is_skip=self.is_skip_compute + ) for _ in range(self.num_outputs) ] for i, features in enumerate(attention_outputs): - self.states[i].update(features, self.step_counter, self.max_order, self.predict_steps, is_first_update=True) + self.states[i].update( + features, self.step_counter, self.max_order, self.predict_steps, is_first_update=True + ) return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs) should_predict = self.states[0].remaining_predictions > 0 @@ -179,9 +236,8 @@ class TaylorSeerAttentionCacheHook(ModelHook): return predicted_outputs[0] if self.num_outputs == 1 else tuple(predicted_outputs) def reset_state(self, module: torch.nn.Module) -> None: - if self.states is not None: - for state in self.states: - state.reset() + self.states = None + def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfig): """ 
+        self.step_counter = -1
+        self.states = None
+
 
 def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfig):
     """
@@ -199,30 +255,57 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfi
     >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
     >>> pipe.to("cuda")
-    >>> config = TaylorSeerCacheConfig(predict_steps=5, max_order=1, warmup_steps=3, taylor_factors_dtype=torch.float32)
+    >>> config = TaylorSeerCacheConfig(
+    ...     predict_steps=5, max_order=1, warmup_steps=3, taylor_factors_dtype=torch.float32, architecture="flux"
+    ... )
     >>> apply_taylorseer_cache(pipe.transformer, config)
     ```
     """
+    if config.skip_compute_identifiers:
+        skip_compute_identifiers = config.skip_compute_identifiers
+    else:
+        skip_compute_identifiers = SKIP_COMPUTE_IDENTIFIERS.get(config.architecture, [])
+
+    if config.special_cache_identifiers:
+        special_cache_identifiers = config.special_cache_identifiers
+    else:
+        special_cache_identifiers = SPECIAL_CACHE_IDENTIFIERS.get(config.architecture, [])
+
+    logger.debug(f"Skip compute identifiers: {skip_compute_identifiers}")
+    logger.debug(f"Special cache identifiers: {special_cache_identifiers}")
+
     for name, submodule in module.named_modules():
-        if isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
+        if skip_compute_identifiers or special_cache_identifiers:
+            if any(re.fullmatch(identifier, name) for identifier in skip_compute_identifiers) or any(
+                re.fullmatch(identifier, name) for identifier in special_cache_identifiers
+            ):
+                logger.debug(f"Applying TaylorSeer cache to {name}")
+                _apply_taylorseer_cache_hook(name, submodule, config)
+        elif isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
             logger.debug(f"Applying TaylorSeer cache to {name}")
             _apply_taylorseer_cache_hook(name, submodule, config)
+
 
 def _apply_taylorseer_cache_hook(name: str, module: Attention, config: TaylorSeerCacheConfig):
     """
     Registers the TaylorSeer hook on the specified attention module.
-
     Args:
         name (str): Name of the module.
         module (Attention): The attention module.
         config (TaylorSeerCacheConfig): Configuration for the cache.
     """
+    # Resolve the skip identifiers the same way as apply_taylorseer_cache, so that
+    # user-provided identifiers are honored here as well, not only the presets.
+    skip_compute_identifiers = config.skip_compute_identifiers or SKIP_COMPUTE_IDENTIFIERS.get(
+        config.architecture, []
+    )
+    is_skip_compute = any(re.fullmatch(identifier, name) for identifier in skip_compute_identifiers)
+
     registry = HookRegistry.check_if_exists_or_initialize(module)
+
     hook = TaylorSeerAttentionCacheHook(
         name,
         config.predict_steps,
         config.max_order,
         config.warmup_steps,
         config.taylor_factors_dtype,
+        is_skip_compute=is_skip_compute,
     )
-    registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK)
\ No newline at end of file
+
+    registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK)
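+
+# Skip-compute modules return uninitialized placeholder tensors on cached steps, so a
+# skip identifier is only safe when every consumer of that module's output is covered
+# by a special-cache identifier (as in the Flux presets above).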