diff --git a/src/diffusers/hooks/taylorseer_cache.py b/src/diffusers/hooks/taylorseer_cache.py
index df37f9251c..e1e0bebcf0 100644
--- a/src/diffusers/hooks/taylorseer_cache.py
+++ b/src/diffusers/hooks/taylorseer_cache.py
@@ -30,34 +30,35 @@ class TaylorSeerCacheConfig:
     Attributes:
         cache_interval (`int`, defaults to `5`):
-            The interval between full computation steps. After a full computation, the cached (predicted) outputs are reused
-            for this many subsequent denoising steps before refreshing with a new full forward pass.
+            The interval between full computation steps. After a full computation, the cached (predicted) outputs are
+            reused for this many subsequent denoising steps before refreshing with a new full forward pass.
         disable_cache_before_step (`int`, defaults to `3`):
-            The denoising step index before which caching is disabled, meaning full computation is performed for the initial
-            steps (0 to disable_cache_before_step - 1) to gather data for Taylor series approximations. During these steps,
-            Taylor factors are updated, but caching/predictions are not applied. Caching begins at this step.
+            The denoising step index before which caching is disabled, meaning full computation is performed for the
+            initial steps (0 to disable_cache_before_step - 1) to gather data for Taylor series approximations. During
+            these steps, Taylor factors are updated, but caching/predictions are not applied. Caching begins at this
+            step.
         disable_cache_after_step (`int`, *optional*, defaults to `None`):
-            The denoising step index after which caching is disabled. If set, for steps >= this value, all modules run full
-            computations without predictions or state updates, ensuring accuracy in later stages if needed.
+            The denoising step index after which caching is disabled. If set, for steps >= this value, all modules run
+            full computations without predictions or state updates, ensuring accuracy in later stages if needed.
         max_order (`int`, defaults to `1`):
-            The highest order in the Taylor series expansion for approximating module outputs. Higher orders provide better
-            approximations but increase computation and memory usage.
+            The highest order in the Taylor series expansion for approximating module outputs. Higher orders provide
+            better approximations but increase computation and memory usage.
         taylor_factors_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
-            Data type used for storing and computing Taylor series factors. Lower precision reduces memory but may affect
-            stability; higher precision improves accuracy at the cost of more memory.
+            Data type used for storing and computing Taylor series factors. Lower precision reduces memory but may
+            affect stability; higher precision improves accuracy at the cost of more memory.
         inactive_identifiers (`List[str]`, *optional*, defaults to `None`):
-            Regex patterns (using `re.fullmatch`) for module names to place in "inactive" mode. In this mode, the module
-            computes fully during initial or refresh steps but returns a zero tensor (matching recorded shape) during
-            prediction steps to skip computation cheaply.
+            Regex patterns (using `re.fullmatch`) for module names to place in "inactive" mode. In this mode, the
+            module computes fully during initial or refresh steps but returns a zero tensor (matching recorded shape)
+            during prediction steps to skip computation cheaply.
         active_identifiers (`List[str]`, *optional*, defaults to `None`):
-            Regex patterns (using `re.fullmatch`) for module names to place in Taylor-series caching mode, where outputs
-            are approximated and cached for reuse.
+            Regex patterns (using `re.fullmatch`) for module names to place in Taylor-series caching mode, where
+            outputs are approximated and cached for reuse.
         use_lite_mode (`bool`, *optional*, defaults to `False`):
             Enables a lightweight TaylorSeer variant that minimizes memory usage by applying predefined patterns for
@@ -118,7 +119,6 @@ class TaylorSeerState:
         self.device: Optional[torch.device] = None
         self.current_step: int = -1
 
-
     def reset(self) -> None:
         self.current_step = -1
         self.last_update_step = None
@@ -223,13 +223,9 @@ class TaylorSeerCacheHook(ModelHook):
         state.current_step += 1
         current_step = state.current_step
         is_warmup_phase = current_step < self.disable_cache_before_step
-        is_compute_interval = ((current_step - self.disable_cache_before_step - 1) % self.cache_interval == 0)
+        is_compute_interval = (current_step - self.disable_cache_before_step - 1) % self.cache_interval == 0
         is_cooldown_phase = self.disable_cache_after_step is not None and current_step >= self.disable_cache_after_step
-        should_compute = (
-            is_warmup_phase
-            or is_compute_interval
-            or is_cooldown_phase
-        )
+        should_compute = is_warmup_phase or is_compute_interval or is_cooldown_phase
         if should_compute:
             outputs = self.fn_ref.original_forward(*args, **kwargs)
             wrapped_outputs = (outputs,) if isinstance(outputs, torch.Tensor) else outputs
@@ -255,8 +251,8 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfi
     """
     Applies the TaylorSeer cache to a given pipeline (typically the transformer / UNet).
 
-    This function hooks selected modules in the model to enable caching or skipping based on the provided configuration,
-    reducing redundant computations in diffusion denoising loops.
+    This function hooks selected modules in the model to enable caching or skipping based on the provided
+    configuration, reducing redundant computations in diffusion denoising loops.
 
     Args:
         module (torch.nn.Module): The model subtree to apply the hooks to.
@@ -338,4 +334,4 @@ def _apply_taylorseer_cache_hook(
         state_manager=state_manager,
     )
 
-    registry.register_hook(hook, _TAYLORSEER_CACHE_HOOK)
\ No newline at end of file
+    registry.register_hook(hook, _TAYLORSEER_CACHE_HOOK)
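---

For reviewers: a standalone sketch of the compute/predict schedule that the simplified `should_compute` expression encodes (hunk `@@ -223,13 +223,9 @@`). The helper name `should_run_full_forward` is hypothetical, used only to illustrate the arithmetic; it is not part of this patch.

from typing import Optional


def should_run_full_forward(
    current_step: int,
    cache_interval: int = 5,
    disable_cache_before_step: int = 3,
    disable_cache_after_step: Optional[int] = None,
) -> bool:
    # Hypothetical helper mirroring the gating logic in TaylorSeerCacheHook.
    # Warmup: run the real forward for the first steps to gather the
    # finite-difference data that seeds the Taylor factors.
    is_warmup_phase = current_step < disable_cache_before_step
    # Periodic refresh: after warmup, recompute roughly once per
    # `cache_interval` steps (Python's modulo keeps this False for the
    # step right after warmup, since (-1) % 5 == 4).
    is_compute_interval = (current_step - disable_cache_before_step - 1) % cache_interval == 0
    # Cooldown: optionally force full computation for late steps.
    is_cooldown_phase = disable_cache_after_step is not None and current_step >= disable_cache_after_step
    return is_warmup_phase or is_compute_interval or is_cooldown_phase


if __name__ == "__main__":
    # With the defaults: steps 0-2 are warmup, then one full forward pass
    # every 5 steps, with Taylor predictions on the skipped steps.
    schedule = [int(should_run_full_forward(step)) for step in range(15)]
    print(schedule)  # [1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1]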
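A hedged usage sketch of the public entry point touched in hunk `@@ -255,8 +251,8 @@`. The checkpoint, the choice of pipeline, and the assumption that `TaylorSeerCacheConfig` accepts its documented attributes as constructor keywords are illustrative, not taken from this patch.

import torch

from diffusers import FluxPipeline
from diffusers.hooks.taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache

# Assumed setup: a Flux-style pipeline whose denoiser lives at `pipe.transformer`.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")

config = TaylorSeerCacheConfig(
    cache_interval=5,                     # one full forward pass every 5 steps after warmup
    disable_cache_before_step=3,          # steps 0-2 always run fully to seed the Taylor factors
    max_order=1,                          # first-order extrapolation of cached outputs
    taylor_factors_dtype=torch.bfloat16,  # storage dtype for the Taylor series factors
)

# Hook the denoiser subtree; on skipped steps, hooked modules return
# Taylor-series predictions instead of running their real forward pass.
apply_taylorseer_cache(pipe.transformer, config)

image = pipe("a photo of an astronaut riding a horse", num_inference_steps=28).images[0]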