diff --git a/src/diffusers/hooks/taylorseer_cache.py b/src/diffusers/hooks/taylorseer_cache.py
index 17d102f589..df37f9251c 100644
--- a/src/diffusers/hooks/taylorseer_cache.py
+++ b/src/diffusers/hooks/taylorseer_cache.py
@@ -29,64 +29,74 @@ class TaylorSeerCacheConfig:
     Configuration for TaylorSeer cache. See: https://huggingface.co/papers/2503.06923

     Attributes:
-        warmup_steps (`int`, defaults to `3`):
-            Number of denoising steps to run with full computation before enabling caching. During warmup, the Taylor
-            series factors are still updated, but no predictions are used.
+        cache_interval (`int`, defaults to `5`):
+            The number of denoising steps between full computations. A full forward pass is recomputed every
+            `cache_interval` steps; the intervening steps reuse cached (predicted) outputs.
-        predict_steps (`int`, defaults to `5`):
-            Number of prediction (cached) steps to take between two full computations. That is, once a module state is
-            refreshed, it will be reused for `predict_steps` subsequent denoising steps, then a new full forward will
-            be computed on the next step.
+        disable_cache_before_step (`int`, defaults to `3`):
+            The denoising step index before which caching is disabled: steps `0` to `disable_cache_before_step - 1`
+            run full computations to gather data for the Taylor series approximation. During these steps, Taylor
+            factors are updated but no predictions are used. Caching begins at this step.
-        stop_predicts (`int`, *optional*, defaults to `None`):
-            Denoising step index at which caching is disabled. If provided, for `self.current_step >= stop_predicts`
-            all modules are evaluated normally (no predictions, no state updates).
+        disable_cache_after_step (`int`, *optional*, defaults to `None`):
+            The denoising step index at which caching is disabled again. If set, for steps `>=` this value, all
+            modules run full computations and no predictions are used, which can help preserve accuracy in the final
+            denoising steps.
         max_order (`int`, defaults to `1`):
-            Maximum order of Taylor series expansion to approximate the features. Higher order gives closer
-            approximation but more compute.
+            The highest order of the Taylor series expansion used to approximate module outputs. Higher orders give
+            closer approximations but increase computation and memory usage.
         taylor_factors_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
-            Data type for computing Taylor series expansion factors. Use lower precision to reduce memory usage. Use
-            higher precision to improve numerical stability.
+            Data type used for storing and computing Taylor series factors. Lower precision reduces memory usage but
+            may affect numerical stability; higher precision improves accuracy at the cost of more memory.
-        skip_identifiers (`List[str]`, *optional*, defaults to `None`):
-            Regex patterns (fullmatch) for module names to be placed in "skip" mode, where the module is evaluated
-            during warmup / refresh, but then replaced by a cheap dummy tensor during prediction steps.
+        inactive_identifiers (`List[str]`, *optional*, defaults to `None`):
+            Regex patterns (matched with `re.fullmatch`) for module names to place in "inactive" mode. In this mode,
+            the module computes fully during warmup/refresh steps but returns a zero tensor (matching the recorded
+            output shape) during prediction steps, skipping its computation cheaply.
-        cache_identifiers (`List[str]`, *optional*, defaults to `None`):
-            Regex patterns (fullmatch) for module names to be placed in Taylor-series caching mode.
+        active_identifiers (`List[str]`, *optional*, defaults to `None`):
+            Regex patterns (matched with `re.fullmatch`) for module names to place in Taylor-series caching mode,
+            where outputs are approximated from cached Taylor factors and reused during prediction steps.
+
+        use_lite_mode (`bool`, *optional*, defaults to `False`):
+            Enables a lightweight TaylorSeer variant that minimizes memory usage by applying predefined patterns for
+            skipping and caching (e.g., skipping transformer blocks and caching output projections). This overrides
+            any custom `inactive_identifiers` or `active_identifiers`.
-        lite (`bool`, *optional*, defaults to `False`):
-            Whether to use a TaylorSeer Lite variant that reduces memory usage. This option overrides any user-provided
-            `skip_identifiers` or `cache_identifiers` patterns.

     Notes:
-        - Patterns are applied with `re.fullmatch` on `module_name`.
-        - If either `skip_identifiers` or `cache_identifiers` is provided, only modules matching at least one of those
-          patterns will be hooked.
-        - If neither is provided, all attention-like modules will be hooked.
+        - Patterns are matched using `re.fullmatch` on the module name.
+        - If `inactive_identifiers` or `active_identifiers` is provided, only matching modules are hooked.
+        - If neither is provided, all attention-like modules are hooked by default.
+        - Example of inactive and active usage:
+          ```
+          def forward(x):
+              x = self.module1(x)  # inactive: returns a zeros tensor shaped like its recorded full-compute output
+              x = self.module2(x)  # active: its output is approximated from cached Taylor factors
+              return x
+          ```
     """

-    warmup_steps: int = 3
-    predict_steps: int = 5
-    stop_predicts: Optional[int] = None
+    cache_interval: int = 5
+    disable_cache_before_step: int = 3
+    disable_cache_after_step: Optional[int] = None
     max_order: int = 1
     taylor_factors_dtype: Optional[torch.dtype] = torch.bfloat16
-    skip_identifiers: Optional[List[str]] = None
-    cache_identifiers: Optional[List[str]] = None
-    lite: bool = False
+    inactive_identifiers: Optional[List[str]] = None
+    active_identifiers: Optional[List[str]] = None
+    use_lite_mode: bool = False

     def __repr__(self) -> str:
         return (
             "TaylorSeerCacheConfig("
-            f"warmup_steps={self.warmup_steps}, "
-            f"predict_steps={self.predict_steps}, "
-            f"stop_predicts={self.stop_predicts}, "
+            f"cache_interval={self.cache_interval}, "
+            f"disable_cache_before_step={self.disable_cache_before_step}, "
+            f"disable_cache_after_step={self.disable_cache_after_step}, "
             f"max_order={self.max_order}, "
             f"taylor_factors_dtype={self.taylor_factors_dtype}, "
-            f"skip_identifiers={self.skip_identifiers}, "
-            f"cache_identifiers={self.cache_identifiers}, "
-            f"lite={self.lite})"
+            f"inactive_identifiers={self.inactive_identifiers}, "
+            f"active_identifiers={self.active_identifiers}, "
+            f"use_lite_mode={self.use_lite_mode})"
         )

@@ -95,70 +105,88 @@ class TaylorSeerState:
         self,
         taylor_factors_dtype: Optional[torch.dtype] = torch.bfloat16,
         max_order: int = 1,
+        is_inactive: bool = False,
     ):
         self.taylor_factors_dtype = taylor_factors_dtype
         self.max_order = max_order
+        self.is_inactive = is_inactive
         self.module_dtypes: Tuple[torch.dtype, ...] = ()
         self.last_update_step: Optional[int] = None
         self.taylor_factors: Dict[int, Dict[int, torch.Tensor]] = {}
-
-        # For skip-mode modules
+        self.inactive_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None
         self.device: Optional[torch.device] = None
-        self.dummy_tensors: Optional[Tuple[torch.Tensor, ...]] = None
+        self.current_step: int = -1
-        self.current_step = -1

     def reset(self) -> None:
+        self.current_step = -1
         self.last_update_step = None
         self.taylor_factors = {}
+        self.inactive_shapes = None
         self.device = None
-        self.dummy_tensors = None
-        self.current_step = -1

     def update(
         self,
         outputs: Tuple[torch.Tensor, ...],
-        current_step: int,
     ) -> None:
         self.module_dtypes = tuple(output.dtype for output in outputs)
-        for i in range(len(outputs)):
-            features = outputs[i].to(self.taylor_factors_dtype)
-            new_factors: Dict[int, torch.Tensor] = {0: features}
-            is_first_update = self.last_update_step is None
-            if not is_first_update:
-                delta_step = current_step - self.last_update_step
-                if delta_step == 0:
-                    raise ValueError("Delta step cannot be zero for TaylorSeer update.")
+        self.device = outputs[0].device
-                # Recursive divided differences up to max_order
-                for j in range(self.max_order):
-                    prev = self.taylor_factors[i].get(j)
-                    if prev is None:
-                        break
-                    new_factors[j + 1] = (new_factors[j] - prev.to(self.taylor_factors_dtype)) / delta_step
-            self.taylor_factors[i] = new_factors
-        self.last_update_step = current_step
+        if self.is_inactive:
+            self.inactive_shapes = tuple(output.shape for output in outputs)
+        else:
+            # Keep the factors from the previous refresh so the divided differences below can build on them.
+            previous_taylor_factors = self.taylor_factors
+            self.taylor_factors = {}
+            for i, output in enumerate(outputs):
+                features = output.to(self.taylor_factors_dtype)
+                new_factors: Dict[int, torch.Tensor] = {0: features}
+                is_first_update = self.last_update_step is None
+                if not is_first_update:
+                    delta_step = self.current_step - self.last_update_step
+                    if delta_step == 0:
+                        raise ValueError("Delta step cannot be zero for TaylorSeer update.")

-    def predict(self, current_step: int) -> torch.Tensor:
+                    # Recursive divided differences up to max_order
+                    prev_factors = previous_taylor_factors.get(i, {})
+                    for j in range(self.max_order):
+                        prev = prev_factors.get(j)
+                        if prev is None:
+                            break
+                        new_factors[j + 1] = (new_factors[j] - prev.to(self.taylor_factors_dtype)) / delta_step
+                self.taylor_factors[i] = new_factors
+
+        self.last_update_step = self.current_step
+
+    def predict(self) -> List[torch.Tensor]:
         if self.last_update_step is None:
             raise ValueError("Cannot predict without prior initialization/update.")
-        step_offset = current_step - self.last_update_step
-
-        if not self.taylor_factors:
-            raise ValueError("Taylor factors empty during prediction.")
+        step_offset = self.current_step - self.last_update_step

         outputs = []
-        for i in range(len(self.module_dtypes)):
-            taylor_factors = self.taylor_factors[i]
-            # Accumulate Taylor series: f(t0 + Δt) ≈ Σ f^{(n)}(t0) * (Δt^n / n!)
-            output = torch.zeros_like(taylor_factors[0])
-            for order, factor in taylor_factors.items():
-                # Note: order starts at 0
-                coeff = (step_offset**order) / math.factorial(order)
-                output = output + factor * coeff
-            outputs.append(output.to(self.module_dtypes[i]))
+        if self.is_inactive:
+            if self.inactive_shapes is None:
+                raise ValueError("Inactive shapes not set during prediction.")
+            for i in range(len(self.module_dtypes)):
+                outputs.append(
+                    torch.zeros(
+                        self.inactive_shapes[i],
+                        dtype=self.module_dtypes[i],
+                        device=self.device,
+                    )
+                )
+        else:
+            if not self.taylor_factors:
+                raise ValueError("Taylor factors empty during prediction.")
+            for i in range(len(self.module_dtypes)):
+                taylor_factors = self.taylor_factors[i]
+                # Accumulate Taylor series: f(t0 + Δt) ≈ Σ f^{(n)}(t0) * (Δt^n / n!)
+                output = torch.zeros_like(taylor_factors[0])
+                for order, factor in taylor_factors.items():
+                    # Note: order starts at 0
+                    coeff = (step_offset**order) / math.factorial(order)
+                    output = output + factor * coeff
+                outputs.append(output.to(self.module_dtypes[i]))

         return outputs

@@ -168,24 +196,18 @@ class TaylorSeerCacheHook(ModelHook):
     def __init__(
         self,
-        module_name: str,
-        predict_steps: int,
-        warmup_steps: int,
+        cache_interval: int,
+        disable_cache_before_step: int,
         taylor_factors_dtype: torch.dtype,
         state_manager: StateManager,
-        stop_predicts: Optional[int] = None,
-        is_skip: bool = False,
+        disable_cache_after_step: Optional[int] = None,
     ):
         super().__init__()
-        self.module_name = module_name
-        self.predict_steps = predict_steps
-        self.warmup_steps = warmup_steps
-        self.stop_predicts = stop_predicts
+        self.cache_interval = cache_interval
+        self.disable_cache_before_step = disable_cache_before_step
+        self.disable_cache_after_step = disable_cache_after_step
         self.taylor_factors_dtype = taylor_factors_dtype
         self.state_manager = state_manager
-        self.is_skip = is_skip
-
-        self.dummy_outputs = None

     def initialize_hook(self, module: torch.nn.Module):
         return module

@@ -194,50 +216,48 @@ class TaylorSeerCacheHook(ModelHook):
         """
         Reset state between sampling runs.
""" - self.dummy_outputs = None - self.current_step = -1 self.state_manager.reset() def new_forward(self, module: torch.nn.Module, *args, **kwargs): state: TaylorSeerState = self.state_manager.get_state() state.current_step += 1 current_step = state.current_step - is_warmup_phase = current_step < self.warmup_steps + is_warmup_phase = current_step < self.disable_cache_before_step + is_compute_interval = ((current_step - self.disable_cache_before_step - 1) % self.cache_interval == 0) + is_cooldown_phase = self.disable_cache_after_step is not None and current_step >= self.disable_cache_after_step should_compute = ( is_warmup_phase - or ((current_step - self.warmup_steps - 1) % self.predict_steps == 0) - or (self.stop_predicts is not None and current_step >= self.stop_predicts) + or is_compute_interval + or is_cooldown_phase ) if should_compute: outputs = self.fn_ref.original_forward(*args, **kwargs) - if not self.is_skip: - state.update((outputs,) if isinstance(outputs, torch.Tensor) else outputs, current_step) - else: - self.dummy_outputs = outputs + wrapped_outputs = (outputs,) if isinstance(outputs, torch.Tensor) else outputs + state.update(wrapped_outputs) return outputs - if self.is_skip: - return self.dummy_outputs - - outputs = state.predict(current_step) - return outputs[0] if len(outputs) == 1 else outputs + outputs_list = state.predict() + return outputs_list[0] if len(outputs_list) == 1 else tuple(outputs_list) def _resolve_patterns(config: TaylorSeerCacheConfig) -> Tuple[List[str], List[str]]: """ - Resolve effective skip and cache pattern lists from config + templates. + Resolve effective inactive and active pattern lists from config + templates. """ - skip_patterns = config.skip_identifiers if config.skip_identifiers is not None else None - cache_patterns = config.cache_identifiers if config.cache_identifiers is not None else None + inactive_patterns = config.inactive_identifiers if config.inactive_identifiers is not None else None + active_patterns = config.active_identifiers if config.active_identifiers is not None else None - return skip_patterns or [], cache_patterns or [] + return inactive_patterns or [], active_patterns or [] def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfig): """ Applies the TaylorSeer cache to a given pipeline (typically the transformer / UNet). + This function hooks selected modules in the model to enable caching or skipping based on the provided configuration, + reducing redundant computations in diffusion denoising loops. + Args: module (torch.nn.Module): The model subtree to apply the hooks to. config (TaylorSeerCacheConfig): Configuration for the cache. @@ -254,60 +274,41 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfi >>> pipe.to("cuda") >>> config = TaylorSeerCacheConfig( - ... predict_steps=5, + ... cache_interval=5, ... max_order=1, - ... warmup_steps=3, + ... disable_cache_before_step=3, ... taylor_factors_dtype=torch.float32, ... 
         >>> pipe.transformer.enable_cache(config)
         ```
     """
-    skip_patterns, cache_patterns = _resolve_patterns(config)
+    inactive_patterns, active_patterns = _resolve_patterns(config)

-    logger.debug("TaylorSeer skip identifiers: %s", skip_patterns)
-    logger.debug("TaylorSeer cache identifiers: %s", cache_patterns)
+    active_patterns = active_patterns or _TRANSFORMER_BLOCK_IDENTIFIERS

-    cache_patterns = cache_patterns or _TRANSFORMER_BLOCK_IDENTIFIERS
-
-    if config.lite:
+    if config.use_lite_mode:
         logger.info("Using TaylorSeer Lite variant for cache.")
-        cache_patterns = _PROJ_OUT_IDENTIFIERS
-        skip_patterns = _BLOCK_IDENTIFIERS
-        if config.skip_identifiers or config.cache_identifiers:
+        active_patterns = _PROJ_OUT_IDENTIFIERS
+        inactive_patterns = _BLOCK_IDENTIFIERS
+        if config.inactive_identifiers or config.active_identifiers:
             logger.warning("Lite mode overrides user patterns.")

     for name, submodule in module.named_modules():
-        matches_skip = any(re.fullmatch(pattern, name) for pattern in skip_patterns)
-        matches_cache = any(re.fullmatch(pattern, name) for pattern in cache_patterns)
-        if not (matches_skip or matches_cache):
+        matches_inactive = any(re.fullmatch(pattern, name) for pattern in inactive_patterns)
+        matches_active = any(re.fullmatch(pattern, name) for pattern in active_patterns)
+        if not (matches_inactive or matches_active):
             continue
-        logger.debug(
-            "Applying TaylorSeer cache to %s (mode=%s)",
-            name,
-            "skip" if matches_skip else "cache",
-        )
-        state_manager = StateManager(
-            TaylorSeerState,
-            init_kwargs={
-                "taylor_factors_dtype": config.taylor_factors_dtype,
-                "max_order": config.max_order,
-            },
-        )
         _apply_taylorseer_cache_hook(
-            name=name,
             module=submodule,
             config=config,
-            is_skip=matches_skip,
-            state_manager=state_manager,
+            is_inactive=matches_inactive,
         )


 def _apply_taylorseer_cache_hook(
-    name: str,
     module: nn.Module,
     config: TaylorSeerCacheConfig,
-    is_skip: bool,
-    state_manager: StateManager,
+    is_inactive: bool,
 ):
     """
     Registers the TaylorSeer hook on the specified nn.Module.

@@ -316,19 +317,25 @@ def _apply_taylorseer_cache_hook(
-        name: Name of the module.
         module: The nn.Module to be hooked.
         config: Cache configuration.
-        is_skip: Whether this module should operate in "skip" mode.
-        state_manager: The state manager for managing hook state.
+        is_inactive: Whether this module should operate in "inactive" mode.
     """
+    state_manager = StateManager(
+        TaylorSeerState,
+        init_kwargs={
+            "taylor_factors_dtype": config.taylor_factors_dtype,
+            "max_order": config.max_order,
+            "is_inactive": is_inactive,
+        },
+    )
+
     registry = HookRegistry.check_if_exists_or_initialize(module)
     hook = TaylorSeerCacheHook(
-        module_name=name,
-        predict_steps=config.predict_steps,
-        warmup_steps=config.warmup_steps,
+        cache_interval=config.cache_interval,
+        disable_cache_before_step=config.disable_cache_before_step,
         taylor_factors_dtype=config.taylor_factors_dtype,
-        stop_predicts=config.stop_predicts,
-        is_skip=is_skip,
+        disable_cache_after_step=config.disable_cache_after_step,
         state_manager=state_manager,
     )
-    registry.register_hook(hook, _TAYLORSEER_CACHE_HOOK)
+    registry.register_hook(hook, _TAYLORSEER_CACHE_HOOK)
\ No newline at end of file
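
For reviewers who want to exercise the renamed API end to end, here is a minimal usage sketch mirroring the docstring example in `apply_taylorseer_cache`. The top-level `TaylorSeerCacheConfig` import path, the Flux checkpoint, the prompt, and the step count are illustrative assumptions, not part of this patch:

```python
import torch

from diffusers import FluxPipeline, TaylorSeerCacheConfig  # assumed top-level export

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")

config = TaylorSeerCacheConfig(
    cache_interval=5,             # a full forward pass every 5 denoising steps, predictions in between
    max_order=1,                  # first-order Taylor expansion of the cached features
    disable_cache_before_step=3,  # steps 0-2 run fully to gather Taylor factors
    taylor_factors_dtype=torch.float32,
)
pipe.transformer.enable_cache(config)

image = pipe("a photo of an astronaut riding a horse", num_inference_steps=28).images[0]
image.save("taylorseer.png")
```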
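The divided-difference bookkeeping in `TaylorSeerState.update` and the prediction formula in `predict` (f(t0 + Δt) ≈ Σ f⁽ⁿ⁾(t0) · Δtⁿ / n!) can be sanity-checked in isolation. A self-contained sketch with a toy linear feature trajectory, using illustrative step values:

```python
import math

import torch

# Toy trajectory f(t) = t * ones(2), "fully computed" at steps 0 and 5.
f0, f1 = torch.zeros(2), 5.0 * torch.ones(2)

# First update (step 0): only the order-0 factor exists.
factors = {0: f0}

# Second update (step 5): order-0 is the fresh output, order-1 the divided difference.
delta_step = 5 - 0
factors = {0: f1, 1: (f1 - factors[0]) / delta_step}  # {0: [5., 5.], 1: [1., 1.]}

# Prediction (step 7): step_offset = 7 - 5 = 2, accumulated exactly as in predict().
step_offset = 2
pred = sum(factor * step_offset**order / math.factorial(order) for order, factor in factors.items())
print(pred)  # tensor([7., 7.]) -- exact here, since the trajectory is linear and max_order >= 1
```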
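Finally, the compute/predict schedule implied by `new_forward` is easiest to see by replaying its arithmetic. The helper below is hypothetical and outside the patch; it just mirrors `should_compute` for the default config values:

```python
from typing import Optional


def step_kind(
    step: int,
    cache_interval: int = 5,
    disable_cache_before_step: int = 3,
    disable_cache_after_step: Optional[int] = None,
) -> str:
    # Mirrors should_compute in TaylorSeerCacheHook.new_forward.
    if step < disable_cache_before_step:
        return "compute (warmup)"
    if disable_cache_after_step is not None and step >= disable_cache_after_step:
        return "compute (cooldown)"
    if (step - disable_cache_before_step - 1) % cache_interval == 0:
        return "compute (refresh)"
    return "predict"


for step in range(12):
    print(step, step_kind(step))
# 0-2: compute (warmup); 3: predict; 4: compute (refresh); 5-8: predict; 9: compute (refresh); 10-11: predict
```

Note that step 3, the first step after warmup, is already predicted: the `- 1` in the modulo places the first refresh at step `disable_cache_before_step + 1`, which relies on the Taylor factors gathered during warmup.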