Mirror of https://github.com/huggingface/diffusers.git
fix format & doc
@@ -34,3 +34,9 @@ Cache methods speedup diffusion transformers by storing and reusing intermediate
 [[autodoc]] FirstBlockCacheConfig
 
 [[autodoc]] apply_first_block_cache
+
+### TaylorSeerCacheConfig
+
+[[autodoc]] TaylorSeerCacheConfig
+
+[[autodoc]] apply_taylorseer_cache
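For orientation, the two autodoc entries added here cover the new config dataclass and its functional entry point. A minimal sketch of how they combine, assuming the `apply_taylorseer_cache(module, config)` signature shown later in this diff and the FLUX checkpoint used in its docstring example:

```python
import torch
from diffusers import FluxPipeline
from diffusers.hooks import TaylorSeerCacheConfig, apply_taylorseer_cache

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Hook the denoising transformer directly with the functional API documented above.
apply_taylorseer_cache(pipe.transformer, TaylorSeerCacheConfig(predict_steps=5, max_order=1))
```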
@@ -25,4 +25,4 @@ if is_torch_available():
     from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
     from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
     from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig
-    from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache
+    from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache
@@ -1,13 +1,13 @@
 import math
 import re
 from dataclasses import dataclass
-from typing import Optional, List, Dict, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
 
-from .hooks import ModelHook, StateManager, HookRegistry
 from ..utils import logging
+from .hooks import HookRegistry, ModelHook, StateManager
 
 
 logger = logging.get_logger(__name__)
@@ -19,60 +19,51 @@ _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = (
 )
 _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("^temporal_transformer_blocks.*attn",)
 _TRANSFORMER_BLOCK_IDENTIFIERS = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS + _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
-_BLOCK_IDENTIFIERS = (
-    "^[^.]*block[^.]*\\.[^.]+$",
-)
+_BLOCK_IDENTIFIERS = ("^[^.]*block[^.]*\\.[^.]+$",)
 _PROJ_OUT_IDENTIFIERS = ("^proj_out$",)
 
 
 @dataclass
 class TaylorSeerCacheConfig:
     """
-    Configuration for TaylorSeer cache.
-    See: https://huggingface.co/papers/2503.06923
+    Configuration for TaylorSeer cache. See: https://huggingface.co/papers/2503.06923
 
     Attributes:
         warmup_steps (`int`, defaults to `3`):
-            Number of denoising steps to run with full computation
-            before enabling caching. During warmup, the Taylor series factors
-            are still updated, but no predictions are used.
+            Number of denoising steps to run with full computation before enabling caching. During warmup, the Taylor
+            series factors are still updated, but no predictions are used.
 
         predict_steps (`int`, defaults to `5`):
-            Number of prediction (cached) steps to take between two full
-            computations. That is, once a module state is refreshed, it will
-            be reused for `predict_steps` subsequent denoising steps, then a new
-            full forward will be computed on the next step.
+            Number of prediction (cached) steps to take between two full computations. That is, once a module state is
+            refreshed, it will be reused for `predict_steps` subsequent denoising steps, then a new full forward will
+            be computed on the next step.
 
         stop_predicts (`int`, *optional*, defaults to `None`):
-            Denoising step index at which caching is disabled.
-            If provided, for `self.current_step >= stop_predicts` all modules are
-            evaluated normally (no predictions, no state updates).
+            Denoising step index at which caching is disabled. If provided, for `self.current_step >= stop_predicts`
+            all modules are evaluated normally (no predictions, no state updates).
 
         max_order (`int`, defaults to `1`):
-            Maximum order of Taylor series expansion to approximate the
-            features. Higher order gives closer approximation but more compute.
+            Maximum order of Taylor series expansion to approximate the features. Higher order gives closer
+            approximation but more compute.
 
         taylor_factors_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
-            Data type for computing Taylor series expansion factors.
-            Use lower precision to reduce memory usage.
-            Use higher precision to improve numerical stability.
+            Data type for computing Taylor series expansion factors. Use lower precision to reduce memory usage. Use
+            higher precision to improve numerical stability.
 
         skip_identifiers (`List[str]`, *optional*, defaults to `None`):
-            Regex patterns (fullmatch) for module names to be placed in
-            "skip" mode, where the module is evaluated during warmup /
-            refresh, but then replaced by a cheap dummy tensor during
-            prediction steps.
+            Regex patterns (fullmatch) for module names to be placed in "skip" mode, where the module is evaluated
+            during warmup / refresh, but then replaced by a cheap dummy tensor during prediction steps.
 
         cache_identifiers (`List[str]`, *optional*, defaults to `None`):
-            Regex patterns (fullmatch) for module names to be placed in
-            Taylor-series caching mode.
+            Regex patterns (fullmatch) for module names to be placed in Taylor-series caching mode.
 
         lite (`bool`, *optional*, defaults to `False`):
-            Whether to use a TaylorSeer Lite variant that reduces memory usage. This option overrides
-            any user-provided `skip_identifiers` or `cache_identifiers` patterns.
+            Whether to use a TaylorSeer Lite variant that reduces memory usage. This option overrides any user-provided
+            `skip_identifiers` or `cache_identifiers` patterns.
+
     Notes:
         - Patterns are applied with `re.fullmatch` on `module_name`.
-        - If either `skip_identifiers` or `cache_identifiers` is provided, only modules matching at least
-          one of those patterns will be hooked.
+        - If either `skip_identifiers` or `cache_identifiers` is provided, only modules matching at least one of those
+          patterns will be hooked.
         - If neither is provided, all attention-like modules will be hooked.
     """
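Read together, the attributes above translate into a construction like the following sketch; the values and the regex pattern are illustrative, not defaults taken from this diff:

```python
import torch
from diffusers.hooks import TaylorSeerCacheConfig

config = TaylorSeerCacheConfig(
    warmup_steps=3,                       # full forwards while the Taylor factors are built up
    predict_steps=5,                      # reuse predictions for 5 steps between refreshes
    stop_predicts=45,                     # illustrative: run the final steps without caching
    max_order=1,                          # first-order Taylor expansion of cached features
    taylor_factors_dtype=torch.bfloat16,  # lower precision saves memory, higher is more stable
    # Patterns go through re.fullmatch, so "transformer_blocks.7.attn" matches this
    # pattern while "transformer_blocks.7.attn.to_q" does not.
    cache_identifiers=[r"^transformer_blocks\.\d+\.attn$"],
)
```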
@@ -255,13 +246,13 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfig
     ```python
     >>> import torch
     >>> from diffusers import FluxPipeline, TaylorSeerCacheConfig
-    >>>
+
     >>> pipe = FluxPipeline.from_pretrained(
     ...     "black-forest-labs/FLUX.1-dev",
     ...     torch_dtype=torch.bfloat16,
     ... )
     >>> pipe.to("cuda")
-    >>>
+
     >>> config = TaylorSeerCacheConfig(
     ...     predict_steps=5,
     ...     max_order=1,
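The hunk is cut off mid-example; presumably the docstring continues by closing the config and enabling it on the denoiser. A hedged sketch of that continuation, using the `enable_cache` counterpart of the `disable_cache` method edited in the next hunk:

```python
>>> # Presumed continuation (not shown in this hunk):
>>> config = TaylorSeerCacheConfig(predict_steps=5, max_order=1)
>>> pipe.transformer.enable_cache(config)
>>> image = pipe("A cat holding a sign that says hello world").images[0]
```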
@@ -93,7 +93,13 @@ class CacheMixin:
         self._cache_config = config
 
     def disable_cache(self) -> None:
-        from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig, TaylorSeerCacheConfig
+        from ..hooks import (
+            FasterCacheConfig,
+            FirstBlockCacheConfig,
+            HookRegistry,
+            PyramidAttentionBroadcastConfig,
+            TaylorSeerCacheConfig,
+        )
         from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
         from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
         from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
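For context, the private hook-name constants imported here are what `disable_cache` uses to strip cache hooks from the model's hook registry. A rough sketch of that removal pattern; the exact `HookRegistry` calls are assumed from the surrounding codebase, not shown in this diff:

```python
# Assumed internals sketch: look up the module's registry and drop hooks by name.
registry = HookRegistry.check_if_exists_or_initialize(self)  # assumed helper
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)  # assumed signature
```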