From 4fb3f53b6c49da08d8d0171c8fe120daf81edcc1 Mon Sep 17 00:00:00 2001 From: toilaluan Date: Wed, 3 Dec 2025 10:43:01 +0000 Subject: [PATCH] rename identifiers, use more expressive taylor predict loop --- src/diffusers/hooks/taylorseer_cache.py | 31 ++++++++++++------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/src/diffusers/hooks/taylorseer_cache.py b/src/diffusers/hooks/taylorseer_cache.py index 4b9f4dd479..607d652f4a 100644 --- a/src/diffusers/hooks/taylorseer_cache.py +++ b/src/diffusers/hooks/taylorseer_cache.py @@ -51,12 +51,12 @@ class TaylorSeerCacheConfig: Data type used for storing and computing Taylor series factors. Lower precision reduces memory but may affect stability; higher precision improves accuracy at the cost of more memory. - inactive_identifiers (`List[str]`, *optional*, defaults to `None`): - Regex patterns (using `re.fullmatch`) for module names to place in "inactive" mode. In this mode, the + skip_predict_identifiers (`List[str]`, *optional*, defaults to `None`): + Regex patterns (using `re.fullmatch`) for module names to place in "skip-predict" mode. In this mode, the module computes fully during initial or refresh steps but returns a zero tensor (matching recorded shape) during prediction steps to skip computation cheaply. - active_identifiers (`List[str]`, *optional*, defaults to `None`): + cache_identifiers (`List[str]`, *optional*, defaults to `None`): Regex patterns (using `re.fullmatch`) for module names to place in Taylor-series caching mode, where outputs are approximated and cached for reuse. 
@@ -83,8 +83,8 @@ class TaylorSeerCacheConfig: disable_cache_after_step: Optional[int] = None max_order: int = 1 taylor_factors_dtype: Optional[torch.dtype] = torch.bfloat16 - inactive_identifiers: Optional[List[str]] = None - active_identifiers: Optional[List[str]] = None + skip_predict_identifiers: Optional[List[str]] = None + cache_identifiers: Optional[List[str]] = None use_lite_mode: bool = False def __repr__(self) -> str: @@ -95,8 +95,8 @@ class TaylorSeerCacheConfig: f"disable_cache_after_step={self.disable_cache_after_step}, " f"max_order={self.max_order}, " f"taylor_factors_dtype={self.taylor_factors_dtype}, " - f"inactive_identifiers={self.inactive_identifiers}, " - f"active_identifiers={self.active_identifiers}, " + f"skip_predict_identifiers={self.skip_predict_identifiers}, " + f"cache_identifiers={self.cache_identifiers}, " f"use_lite_mode={self.use_lite_mode})" ) @@ -136,7 +136,6 @@ class TaylorSeerState: if self.is_inactive: self.inactive_shapes = tuple(output.shape for output in outputs) else: - self.taylor_factors = {} for i, features in enumerate(outputs): new_factors: Dict[int, torch.Tensor] = {0: features} is_first_update = self.last_update_step is None @@ -179,17 +178,17 @@ class TaylorSeerState: else: if not self.taylor_factors: raise ValueError("Taylor factors empty during prediction.") - for i in range(len(self.module_dtypes)): + num_outputs = len(self.taylor_factors) + num_orders = len(self.taylor_factors[0]) + for i in range(num_outputs): output_dtype = self.module_dtypes[i] taylor_factors = self.taylor_factors[i] - # Accumulate Taylor series: f(t0 + Δt) ≈ Σ f^{(n)}(t0) * (Δt^n / n!) 
output = torch.zeros_like(taylor_factors[0], dtype=output_dtype) - for order, factor in taylor_factors.items(): - # Note: order starts at 0 + for order in range(num_orders): coeff = (step_offset**order) / math.factorial(order) + factor = taylor_factors[order] output = output + factor.to(output_dtype) * coeff outputs.append(output) - return outputs @@ -243,8 +242,8 @@ def _resolve_patterns(config: TaylorSeerCacheConfig) -> Tuple[List[str], List[st Resolve effective inactive and active pattern lists from config + templates. """ - inactive_patterns = config.inactive_identifiers if config.inactive_identifiers is not None else None - active_patterns = config.active_identifiers if config.active_identifiers is not None else None + inactive_patterns = config.skip_predict_identifiers if config.skip_predict_identifiers is not None else None + active_patterns = config.cache_identifiers if config.cache_identifiers is not None else None return inactive_patterns or [], active_patterns or [] @@ -288,7 +287,7 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfi logger.info("Using TaylorSeer Lite variant for cache.") active_patterns = _PROJ_OUT_IDENTIFIERS inactive_patterns = _BLOCK_IDENTIFIERS - if config.inactive_identifiers or config.active_identifiers: + if config.skip_predict_identifiers or config.cache_identifiers: logger.warning("Lite mode overrides user patterns.") for name, submodule in module.named_modules():