mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

fix format & doc

toilaluan
2025-11-25 06:02:13 +00:00
parent b3217139f5
commit 2be31f856e
4 changed files with 39 additions and 36 deletions

docs/source/en/api/cache.md

@@ -34,3 +34,9 @@ Cache methods speedup diffusion transformers by storing and reusing intermediate
[[autodoc]] FirstBlockCacheConfig
[[autodoc]] apply_first_block_cache
+### TaylorSeerCacheConfig
+[[autodoc]] TaylorSeerCacheConfig
+[[autodoc]] apply_taylorseer_cache
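These autodoc entries cover the two public entry points added by this commit. Judging from the signature visible further down (`apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfig)`), a minimal sketch of the functional API; the checkpoint id comes from the Flux example later in this diff, while the prompt and the `transformer` attribute are placeholders:

```python
import torch
from diffusers import FluxPipeline, TaylorSeerCacheConfig
from diffusers.hooks import apply_taylorseer_cache  # export added by this commit

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Hook the denoiser directly with the functional API.
apply_taylorseer_cache(pipe.transformer, TaylorSeerCacheConfig(predict_steps=5, max_order=1))
image = pipe("a photo of a cat").images[0]
```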

src/diffusers/hooks/__init__.py

@@ -25,4 +25,4 @@ if is_torch_available():
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig
-from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache
+from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache

src/diffusers/hooks/taylorseer_cache.py

@@ -1,13 +1,13 @@
import math
import re
from dataclasses import dataclass
-from typing import Optional, List, Dict, Tuple
+from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
-from .hooks import ModelHook, StateManager, HookRegistry
from ..utils import logging
+from .hooks import HookRegistry, ModelHook, StateManager
logger = logging.get_logger(__name__)
@@ -19,60 +19,51 @@ _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = (
)
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("^temporal_transformer_blocks.*attn",)
_TRANSFORMER_BLOCK_IDENTIFIERS = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS + _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
-_BLOCK_IDENTIFIERS = (
-"^[^.]*block[^.]*\\.[^.]+$",
-)
+_BLOCK_IDENTIFIERS = ("^[^.]*block[^.]*\\.[^.]+$",)
_PROJ_OUT_IDENTIFIERS = ("^proj_out$",)
@dataclass
class TaylorSeerCacheConfig:
"""
-Configuration for TaylorSeer cache.
-See: https://huggingface.co/papers/2503.06923
+Configuration for TaylorSeer cache. See: https://huggingface.co/papers/2503.06923
Attributes:
warmup_steps (`int`, defaults to `3`):
-Number of denoising steps to run with full computation
-before enabling caching. During warmup, the Taylor series factors
-are still updated, but no predictions are used.
+Number of denoising steps to run with full computation before enabling caching. During warmup, the Taylor
+series factors are still updated, but no predictions are used.
predict_steps (`int`, defaults to `5`):
-Number of prediction (cached) steps to take between two full
-computations. That is, once a module state is refreshed, it will
-be reused for `predict_steps` subsequent denoising steps, then a new
-full forward will be computed on the next step.
+Number of prediction (cached) steps to take between two full computations. That is, once a module state is
+refreshed, it will be reused for `predict_steps` subsequent denoising steps, then a new full forward will
+be computed on the next step.
stop_predicts (`int`, *optional*, defaults to `None`):
-Denoising step index at which caching is disabled.
-If provided, for `self.current_step >= stop_predicts` all modules are
-evaluated normally (no predictions, no state updates).
+Denoising step index at which caching is disabled. If provided, for `self.current_step >= stop_predicts`
+all modules are evaluated normally (no predictions, no state updates).
max_order (`int`, defaults to `1`):
-Maximum order of Taylor series expansion to approximate the
-features. Higher order gives closer approximation but more compute.
+Maximum order of Taylor series expansion to approximate the features. Higher order gives closer
+approximation but more compute.
taylor_factors_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
-Data type for computing Taylor series expansion factors.
-Use lower precision to reduce memory usage.
-Use higher precision to improve numerical stability.
+Data type for computing Taylor series expansion factors. Use lower precision to reduce memory usage. Use
+higher precision to improve numerical stability.
skip_identifiers (`List[str]`, *optional*, defaults to `None`):
-Regex patterns (fullmatch) for module names to be placed in
-"skip" mode, where the module is evaluated during warmup /
-refresh, but then replaced by a cheap dummy tensor during
-prediction steps.
+Regex patterns (fullmatch) for module names to be placed in "skip" mode, where the module is evaluated
+during warmup / refresh, but then replaced by a cheap dummy tensor during prediction steps.
cache_identifiers (`List[str]`, *optional*, defaults to `None`):
-Regex patterns (fullmatch) for module names to be placed in
-Taylor-series caching mode.
+Regex patterns (fullmatch) for module names to be placed in Taylor-series caching mode.
lite (`bool`, *optional*, defaults to `False`):
-Whether to use a TaylorSeer Lite variant that reduces memory usage. This option overrides
-any user-provided `skip_identifiers` or `cache_identifiers` patterns.
+Whether to use a TaylorSeer Lite variant that reduces memory usage. This option overrides any user-provided
+`skip_identifiers` or `cache_identifiers` patterns.
Notes:
- Patterns are applied with `re.fullmatch` on `module_name`.
-- If either `skip_identifiers` or `cache_identifiers` is provided, only modules matching at least
-one of those patterns will be hooked.
+- If either `skip_identifiers` or `cache_identifiers` is provided, only modules matching at least one of those
+patterns will be hooked.
- If neither is provided, all attention-like modules will be hooked.
"""
@@ -255,13 +246,13 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfi
```python
>>> import torch
>>> from diffusers import FluxPipeline, TaylorSeerCacheConfig
>>>
>>> pipe = FluxPipeline.from_pretrained(
... "black-forest-labs/FLUX.1-dev",
... torch_dtype=torch.bfloat16,
... )
>>> pipe.to("cuda")
>>>
>>> config = TaylorSeerCacheConfig(
... predict_steps=5,
... max_order=1,
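The docstring example is cut off at the hunk boundary above. A complete version of the same flow through the `CacheMixin` interface (the `enable_cache` entry point and `transformer` attribute are assumptions based on the mixin shown in the next file; the prompt is a placeholder):

```python
import torch
from diffusers import FluxPipeline, TaylorSeerCacheConfig

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")

config = TaylorSeerCacheConfig(
    predict_steps=5,
    max_order=1,
    taylor_factors_dtype=torch.bfloat16,  # documented default
)
pipe.transformer.enable_cache(config)

image = pipe("a photo of a cat").images[0]
```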

src/diffusers/models/cache_utils.py

@@ -93,7 +93,13 @@ class CacheMixin:
self._cache_config = config
def disable_cache(self) -> None:
-from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig, TaylorSeerCacheConfig
+from ..hooks import (
+FasterCacheConfig,
+FirstBlockCacheConfig,
+HookRegistry,
+PyramidAttentionBroadcastConfig,
+TaylorSeerCacheConfig,
+)
from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
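On the disable side, `disable_cache` matches the stored `_cache_config` against these imported config types and removes the corresponding hooks. A one-line sketch of the round trip, reusing the `pipe` from the example above (attribute name is checkpoint-specific):

```python
# Caching is toggled on the model, not the pipeline.
pipe.transformer.disable_cache()
```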