From 2be31f856e628fde5dfd8f96fa04401dca2397fa Mon Sep 17 00:00:00 2001
From: toilaluan
Date: Tue, 25 Nov 2025 06:02:13 +0000
Subject: [PATCH] fix format & doc

---
 docs/source/en/api/cache.md             |  6 +++
 src/diffusers/hooks/__init__.py         |  2 +-
 src/diffusers/hooks/taylorseer_cache.py | 59 +++++++++++--------
 src/diffusers/models/cache_utils.py     |  8 +++-
 4 files changed, 39 insertions(+), 36 deletions(-)

diff --git a/docs/source/en/api/cache.md b/docs/source/en/api/cache.md
index 9ba4742085..c93dcad438 100644
--- a/docs/source/en/api/cache.md
+++ b/docs/source/en/api/cache.md
@@ -34,3 +34,9 @@ Cache methods speedup diffusion transformers by storing and reusing intermediate
 [[autodoc]] FirstBlockCacheConfig
 
 [[autodoc]] apply_first_block_cache
+
+### TaylorSeerCacheConfig
+
+[[autodoc]] TaylorSeerCacheConfig
+
+[[autodoc]] apply_taylorseer_cache
diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py
index 1d9d43d96b..eb12b8a52a 100644
--- a/src/diffusers/hooks/__init__.py
+++ b/src/diffusers/hooks/__init__.py
@@ -25,4 +25,4 @@ if is_torch_available():
     from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
     from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
     from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig
-    from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache
\ No newline at end of file
+    from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache
diff --git a/src/diffusers/hooks/taylorseer_cache.py b/src/diffusers/hooks/taylorseer_cache.py
index f400576fed..17d102f589 100644
--- a/src/diffusers/hooks/taylorseer_cache.py
+++ b/src/diffusers/hooks/taylorseer_cache.py
@@ -1,13 +1,13 @@
 import math
 import re
 from dataclasses import dataclass
-from typing import Optional, List, Dict, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
 
-from .hooks import ModelHook, StateManager, HookRegistry
 from ..utils import logging
+from .hooks import HookRegistry, ModelHook, StateManager
 
 
 logger = logging.get_logger(__name__)
@@ -19,60 +19,51 @@ _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = (
 )
 _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("^temporal_transformer_blocks.*attn",)
 _TRANSFORMER_BLOCK_IDENTIFIERS = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS + _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
-_BLOCK_IDENTIFIERS = (
-    "^[^.]*block[^.]*\\.[^.]+$",
-)
+_BLOCK_IDENTIFIERS = ("^[^.]*block[^.]*\\.[^.]+$",)
 _PROJ_OUT_IDENTIFIERS = ("^proj_out$",)
 
+
 @dataclass
 class TaylorSeerCacheConfig:
     """
-    Configuration for TaylorSeer cache.
-    See: https://huggingface.co/papers/2503.06923
+    Configuration for TaylorSeer cache. See: https://huggingface.co/papers/2503.06923
 
     Attributes:
         warmup_steps (`int`, defaults to `3`):
-            Number of denoising steps to run with full computation
-            before enabling caching. During warmup, the Taylor series factors
-            are still updated, but no predictions are used.
+            Number of denoising steps to run with full computation before enabling caching. During warmup, the Taylor
+            series factors are still updated, but no predictions are used.
         predict_steps (`int`, defaults to `5`):
-            Number of prediction (cached) steps to take between two full
-            computations. That is, once a module state is refreshed, it will
-            be reused for `predict_steps` subsequent denoising steps, then a new
-            full forward will be computed on the next step.
+            Number of prediction (cached) steps to take between two full computations. That is, once a module state is
+            refreshed, it will be reused for `predict_steps` subsequent denoising steps, then a new full forward will
+            be computed on the next step.
         stop_predicts (`int`, *optional*, defaults to `None`):
-            Denoising step index at which caching is disabled.
-            If provided, for `self.current_step >= stop_predicts` all modules are
-            evaluated normally (no predictions, no state updates).
+            Denoising step index at which caching is disabled. If provided, for `self.current_step >= stop_predicts`
+            all modules are evaluated normally (no predictions, no state updates).
         max_order (`int`, defaults to `1`):
-            Maximum order of Taylor series expansion to approximate the
-            features. Higher order gives closer approximation but more compute.
+            Maximum order of Taylor series expansion to approximate the features. Higher order gives closer
+            approximation but more compute.
         taylor_factors_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
-            Data type for computing Taylor series expansion factors.
-            Use lower precision to reduce memory usage.
-            Use higher precision to improve numerical stability.
+            Data type for computing Taylor series expansion factors. Use lower precision to reduce memory usage. Use
+            higher precision to improve numerical stability.
         skip_identifiers (`List[str]`, *optional*, defaults to `None`):
-            Regex patterns (fullmatch) for module names to be placed in
-            "skip" mode, where the module is evaluated during warmup /
-            refresh, but then replaced by a cheap dummy tensor during
-            prediction steps.
+            Regex patterns (fullmatch) for module names to be placed in "skip" mode, where the module is evaluated
+            during warmup / refresh, but then replaced by a cheap dummy tensor during prediction steps.
         cache_identifiers (`List[str]`, *optional*, defaults to `None`):
-            Regex patterns (fullmatch) for module names to be placed in
-            Taylor-series caching mode.
+            Regex patterns (fullmatch) for module names to be placed in Taylor-series caching mode.
         lite (`bool`, *optional*, defaults to `False`):
-            Whether to use a TaylorSeer Lite variant that reduces memory usage. This option overrides
-            any user-provided `skip_identifiers` or `cache_identifiers` patterns.
+            Whether to use a TaylorSeer Lite variant that reduces memory usage. This option overrides any user-provided
+            `skip_identifiers` or `cache_identifiers` patterns.
 
     Notes:
         - Patterns are applied with `re.fullmatch` on `module_name`.
-        - If either `skip_identifiers` or `cache_identifiers` is provided, only modules matching at least
-          one of those patterns will be hooked.
+        - If either `skip_identifiers` or `cache_identifiers` is provided, only modules matching at least one of those
+          patterns will be hooked.
         - If neither is provided, all attention-like modules will be hooked.
     """
@@ -255,13 +246,13 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfi
     ```python
     >>> import torch
     >>> from diffusers import FluxPipeline, TaylorSeerCacheConfig
-    >>>
+
    >>> pipe = FluxPipeline.from_pretrained(
     ...     "black-forest-labs/FLUX.1-dev",
     ...     torch_dtype=torch.bfloat16,
     ... )
     >>> pipe.to("cuda")
-    >>>
+
    >>> config = TaylorSeerCacheConfig(
     ...     predict_steps=5,
     ...     max_order=1,
diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py
index 56de1e6463..f4ad1af278 100644
--- a/src/diffusers/models/cache_utils.py
+++ b/src/diffusers/models/cache_utils.py
@@ -93,7 +93,13 @@ class CacheMixin:
         self._cache_config = config
 
     def disable_cache(self) -> None:
-        from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig, TaylorSeerCacheConfig
+        from ..hooks import (
+            FasterCacheConfig,
+            FirstBlockCacheConfig,
+            HookRegistry,
+            PyramidAttentionBroadcastConfig,
+            TaylorSeerCacheConfig,
+        )
         from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
         from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
         from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
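
The doctest hunk above is cut off by the hunk boundary at `max_order=1,`, so the patch never shows the cache actually being enabled. A minimal end-to-end sketch of the intended flow, assuming the remaining config arguments follow the docstring defaults (`warmup_steps=3`, `taylor_factors_dtype=torch.bfloat16`) and that the config is applied through the `CacheMixin.enable_cache` entry point visible in the `cache_utils.py` hunk:

```python
import torch

from diffusers import FluxPipeline, TaylorSeerCacheConfig

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

# The first `warmup_steps` denoising steps run fully; afterwards each refreshed
# module state is reused for `predict_steps` steps before the next full forward.
config = TaylorSeerCacheConfig(
    warmup_steps=3,  # assumed: docstring default, not shown in the truncated hunk
    predict_steps=5,
    max_order=1,
    taylor_factors_dtype=torch.bfloat16,  # assumed: docstring default
)
pipe.transformer.enable_cache(config)

image = pipe("a photo of an astronaut riding a horse").images[0]

# The disable_cache() patched above removes the hooks and restores full computation.
pipe.transformer.disable_cache()
```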
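Since the docstring notes that `skip_identifiers` and `cache_identifiers` are applied with `re.fullmatch` against module names, here is a quick illustration of what the reformatted `_BLOCK_IDENTIFIERS` pattern selects; the module names are made up for the example:

```python
import re

# _BLOCK_IDENTIFIERS from the hunk above: a two-segment name whose first
# segment contains "block", i.e. a top-level block entry, not its submodules.
pattern = "^[^.]*block[^.]*\\.[^.]+$"

for name in [
    "transformer_blocks.0",         # matches: block container plus index
    "transformer_blocks.0.attn",    # no match: three dot-separated segments
    "single_transformer_blocks.7",  # matches
    "proj_out",                     # no match: covered by _PROJ_OUT_IDENTIFIERS instead
]:
    print(f"{name}: {bool(re.fullmatch(pattern, name))}")
```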
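As for `max_order`: per the paper linked in the docstring, the cache approximates module outputs between full forwards with a truncated Taylor expansion over denoising steps. A schematic of the order-1 case, not the hook's actual implementation:

```python
import torch

# Finite-difference Taylor factors kept from the last two full computations.
f_prev = torch.randn(2, 16)  # features at the previous full step
f_curr = torch.randn(2, 16)  # features at the most recent full step
first_order = f_curr - f_prev


def predict(k: int) -> torch.Tensor:
    # F(t + k) ~ F(t) + k * dF(t), i.e. max_order=1; higher orders add
    # k**2 / 2! * d2F(t) and so on, at the cost of storing more factors.
    return f_curr + k * first_order


# Reused in place of a full forward for `predict_steps` consecutive steps.
approx = predict(3)
```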