mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

refactor to use state manager

This commit is contained in:
toilaluan
2025-11-25 05:28:00 +00:00
parent 9083e1eba5
commit a8ea383044
2 changed files with 136 additions and 366 deletions

View File

@@ -4,36 +4,25 @@ from dataclasses import dataclass
from typing import Optional, List, Dict, Tuple
import torch
import torch.nn as nn
from .hooks import ModelHook
from ..models.attention import Attention
from ..models.attention import AttentionModuleMixin
from ._common import _ATTENTION_CLASSES
from ..hooks import HookRegistry
from .hooks import ModelHook, StateManager, HookRegistry
from ..utils import logging
logger = logging.get_logger(__name__)
_TAYLORSEER_ATTENTION_CACHE_HOOK = "taylorseer_attention_cache"
# Predefined cache templates for optimized architectures
_CACHE_TEMPLATES: Dict[str, Dict[str, List[str]]] = {
"flux": {
"cache": [
r"transformer_blocks\.\d+\.attn",
r"transformer_blocks\.\d+\.ff",
r"transformer_blocks\.\d+\.ff_context",
r"single_transformer_blocks\.\d+\.proj_out",
],
"skip": [
r"single_transformer_blocks\.\d+\.attn",
r"single_transformer_blocks\.\d+\.proj_mlp",
r"single_transformer_blocks\.\d+\.act_mlp",
],
},
}
_TAYLORSEER_CACHE_HOOK = "taylorseer_cache"
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = (
"^blocks.*attn",
"^transformer_blocks.*attn",
"^single_transformer_blocks.*attn",
)
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("^temporal_transformer_blocks.*attn",)
_TRANSFORMER_BLOCK_IDENTIFIERS = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS + _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
_BLOCK_IDENTIFIERS = (
"^[^.]*block[^.]*\\.[^.]+$",
)
_PROJ_OUT_IDENTIFIERS = ("^proj_out$",)
@dataclass
class TaylorSeerCacheConfig:
@@ -43,37 +32,29 @@ class TaylorSeerCacheConfig:
Attributes:
warmup_steps (`int`, defaults to `3`):
Number of *outer* diffusion steps to run with full computation
Number of denoising steps to run with full computation
before enabling caching. During warmup, the Taylor series factors
are still updated, but no predictions are used.
predict_steps (`int`, defaults to `5`):
Number of prediction (cached) steps to take between two full
computations. That is, once a module state is refreshed, it will
be reused for `predict_steps` subsequent outer steps, then a new
be reused for `predict_steps` subsequent denoising steps, then a new
full forward will be computed on the next step.
stop_predicts (`int`, *optional*, defaults to `None`):
Outer diffusion step index at which caching is disabled.
If provided, for `true_step >= stop_predicts` all modules are
Denoising step index at which caching is disabled.
If provided, for `self.current_step >= stop_predicts` all modules are
evaluated normally (no predictions, no state updates).
max_order (`int`, defaults to `1`):
Maximum order of Taylor series expansion to approximate the
features. Higher order gives closer approximation but more compute.
num_inner_loops (`int`, defaults to `1`):
Number of inner loops per outer diffusion step. For example,
with classifier-free guidance (CFG) you typically have 2 inner
loops: unconditional and conditional branches.
taylor_factors_dtype (`torch.dtype`, defaults to `torch.float32`):
taylor_factors_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
Data type for computing Taylor series expansion factors.
architecture (`str`, *optional*, defaults to `None`):
If provided, will look up default `cache` and `skip` regex
patterns in `_CACHE_TEMPLATES[architecture]`. These can be
overridden by `skip_identifiers` and `cache_identifiers`.
Use lower precision to reduce memory usage.
Use higher precision to improve numerical stability.
skip_identifiers (`List[str]`, *optional*, defaults to `None`):
Regex patterns (fullmatch) for module names to be placed in
@@ -85,10 +66,12 @@ class TaylorSeerCacheConfig:
Regex patterns (fullmatch) for module names to be placed in
Taylor-series caching mode.
lite (`bool`, *optional*, defaults to `False`):
Whether to use a TaylorSeer Lite variant that reduces memory usage. This option overrides
any user-provided `skip_identifiers` or `cache_identifiers` patterns.
Notes:
- Patterns are applied with `re.fullmatch` on `module_name`.
- If either `skip_identifiers` or `cache_identifiers` is provided
(or inferred from `architecture`), only modules matching at least
- If either `skip_identifiers` or `cache_identifiers` is provided, only modules matching at least
one of those patterns will be hooked.
- If neither is provided, all attention-like modules will be hooked.
"""
@@ -97,11 +80,10 @@ class TaylorSeerCacheConfig:
predict_steps: int = 5
stop_predicts: Optional[int] = None
max_order: int = 1
num_inner_loops: int = 1
taylor_factors_dtype: torch.dtype = torch.float32
architecture: str | None = None
taylor_factors_dtype: Optional[torch.dtype] = torch.bfloat16
skip_identifiers: Optional[List[str]] = None
cache_identifiers: Optional[List[str]] = None
lite: bool = False
def __repr__(self) -> str:
return (
@@ -110,364 +92,153 @@ class TaylorSeerCacheConfig:
f"predict_steps={self.predict_steps}, "
f"stop_predicts={self.stop_predicts}, "
f"max_order={self.max_order}, "
f"num_inner_loops={self.num_inner_loops}, "
f"taylor_factors_dtype={self.taylor_factors_dtype}, "
f"architecture={self.architecture}, "
f"skip_identifiers={self.skip_identifiers}, "
f"cache_identifiers={self.cache_identifiers})"
f"cache_identifiers={self.cache_identifiers}, "
f"lite={self.lite})"
)
@classmethod
def get_identifiers_template(cls) -> Dict[str, Dict[str, List[str]]]:
return _CACHE_TEMPLATES
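For orientation, here is a hedged usage sketch of the refactored config as it now stands. The field names come from the dataclass above; the pipeline, model id, and import path are assumptions for illustration, not part of this diff:

```python
import torch

from diffusers import FluxPipeline
# Import path follows the hooks module shown in this diff; it may also be re-exported elsewhere.
from diffusers.hooks.taylorseer_cache import TaylorSeerCacheConfig

# Assumed setup: any pipeline whose transformer mixes in CacheMixin should work the same way.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

config = TaylorSeerCacheConfig(
    warmup_steps=3,
    predict_steps=5,
    max_order=1,
    taylor_factors_dtype=torch.bfloat16,
    lite=True,  # lite mode overrides any skip/cache identifier patterns
)
pipe.transformer.enable_cache(config)
```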
class TaylorSeerOutputState:
"""
Manages the state for Taylor series-based prediction of a single attention output.
Tracks Taylor expansion factors, last update step, and remaining prediction steps.
The Taylor expansion uses the (outer) timestep as the independent variable.
This class is designed to handle state for a single inner loop index and a single
output (in cases where the module forward returns multiple tensors).
"""
class TaylorSeerState:
def __init__(
self,
module_name: str,
taylor_factors_dtype: torch.dtype,
module_dtype: torch.dtype,
is_skip: bool = False,
taylor_factors_dtype: Optional[torch.dtype] = torch.bfloat16,
max_order: int = 1,
):
self.module_name = module_name
self.taylor_factors_dtype = taylor_factors_dtype
self.module_dtype = module_dtype
self.is_skip = is_skip
self.max_order = max_order
self.remaining_predictions: int = 0
self.module_dtypes: Tuple[torch.dtype, ...] = ()
self.last_update_step: Optional[int] = None
self.taylor_factors: Dict[int, torch.Tensor] = {}
self.taylor_factors: Dict[int, Dict[int, torch.Tensor]] = {}
# For skip-mode modules
self.dummy_shape: Optional[Tuple[int, ...]] = None
self.device: Optional[torch.device] = None
self.dummy_tensor: Optional[torch.Tensor] = None
self.dummy_tensors: Optional[Tuple[torch.Tensor, ...]] = None
self.current_step = -1
def reset(self) -> None:
self.remaining_predictions = 0
self.last_update_step = None
self.taylor_factors = {}
self.dummy_shape = None
self.device = None
self.dummy_tensor = None
self.dummy_tensors = None
self.current_step = -1
def update(
self,
features: torch.Tensor,
outputs: Tuple[torch.Tensor, ...],
current_step: int,
max_order: int,
predict_steps: int,
) -> None:
"""
Update Taylor factors based on the current features and (outer) timestep.
self.module_dtypes = tuple(output.dtype for output in outputs)
for i in range(len(outputs)):
features = outputs[i].to(self.taylor_factors_dtype)
new_factors: Dict[int, torch.Tensor] = {0: features}
is_first_update = self.last_update_step is None
if not is_first_update:
delta_step = current_step - self.last_update_step
if delta_step == 0:
raise ValueError("Delta step cannot be zero for TaylorSeer update.")
For non-skip modules, finite difference approximations for derivatives are
computed using recursive divided differences.
Args:
features: Attention output features to update with.
current_step: Current outer timestep (true diffusion step).
max_order: Maximum Taylor expansion order.
predict_steps: Number of prediction steps to allow after this update.
"""
if self.is_skip:
# For skip modules we only need shape & device and a dummy tensor.
self.dummy_shape = features.shape
self.device = features.device
# zero is safer than uninitialized values for a "skipped" module
self.dummy_tensor = torch.zeros(
self.dummy_shape,
dtype=self.module_dtype,
device=self.device,
)
self.taylor_factors = {}
self.last_update_step = current_step
self.remaining_predictions = predict_steps
return
features = features.to(self.taylor_factors_dtype)
new_factors: Dict[int, torch.Tensor] = {0: features}
is_first_update = self.last_update_step is None
if not is_first_update:
delta_step = current_step - self.last_update_step
if delta_step == 0:
raise ValueError("Delta step cannot be zero for TaylorSeer update.")
# Recursive divided differences up to max_order
for i in range(max_order):
prev = self.taylor_factors.get(i)
if prev is None:
break
new_factors[i + 1] = (new_factors[i] - prev.to(self.taylor_factors_dtype)) / delta_step
# Keep factors in taylor_factors_dtype
self.taylor_factors = new_factors
# Recursive divided differences up to max_order
for j in range(self.max_order):
prev = self.taylor_factors[i].get(j)
if prev is None:
break
new_factors[j + 1] = (new_factors[j] - prev.to(self.taylor_factors_dtype)) / delta_step
self.taylor_factors[i] = new_factors
self.last_update_step = current_step
self.remaining_predictions = predict_steps
if self.module_name == "proj_out":
logger.debug(
"[UPDATE] module=%s remaining_predictions=%d current_step=%d is_first_update=%s",
self.module_name,
self.remaining_predictions,
current_step,
is_first_update,
)
def predict(self, current_step: int) -> torch.Tensor:
"""
Predict features using the Taylor series at the given (outer) timestep.
Args:
current_step: Current outer timestep for prediction.
Returns:
Predicted features in the module's dtype.
"""
if self.is_skip:
if self.dummy_tensor is None:
raise ValueError("Cannot predict for skip module without prior update.")
self.remaining_predictions -= 1
return self.dummy_tensor
if self.last_update_step is None:
raise ValueError("Cannot predict without prior initialization/update.")
step_offset = current_step - self.last_update_step
output: torch.Tensor
if not self.taylor_factors:
raise ValueError("Taylor factors empty during prediction.")
# Accumulate Taylor series: f(t0 + Δt) ≈ Σ f^{(n)}(t0) * (Δt^n / n!)
output = torch.zeros_like(self.taylor_factors[0])
for order, factor in self.taylor_factors.items():
# Note: order starts at 0
coeff = (step_offset**order) / math.factorial(order)
output = output + factor * coeff
outputs = []
for i in range(len(self.module_dtypes)):
taylor_factors = self.taylor_factors[i]
# Accumulate Taylor series: f(t0 + Δt) ≈ Σ f^{(n)}(t0) * (Δt^n / n!)
output = torch.zeros_like(taylor_factors[0])
for order, factor in taylor_factors.items():
# Note: order starts at 0
coeff = (step_offset**order) / math.factorial(order)
output = output + factor * coeff
outputs.append(output.to(self.module_dtypes[i]))
self.remaining_predictions -= 1
out = output.to(self.module_dtype)
if self.module_name == "proj_out":
logger.debug(
"[PREDICT] module=%s remaining_predictions=%d current_step=%d last_update_step=%s",
self.module_name,
self.remaining_predictions,
current_step,
self.last_update_step,
)
return out
return outputs
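As a side note, the update/predict pair above boils down to the following standalone numerical recipe (a minimal sketch with illustrative names, not the hook classes themselves): the factors are finite-difference estimates of derivatives with respect to the denoising step index, and prediction is a factorial-weighted Taylor sum.

```python
import math

import torch


def update_factors(factors, features, step, last_step, max_order):
    # The new zeroth-order factor is the fresh module output itself.
    new_factors = {0: features}
    if last_step is not None:
        delta = step - last_step
        for order in range(max_order):
            prev = factors.get(order)
            if prev is None:
                break
            # Recursive divided differences approximate higher-order derivatives.
            new_factors[order + 1] = (new_factors[order] - prev) / delta
    return new_factors


def predict(factors, step, last_step):
    # f(t0 + dt) ≈ sum_n f^(n)(t0) * dt^n / n!
    dt = step - last_step
    out = torch.zeros_like(factors[0])
    for order, factor in factors.items():
        out = out + factor * (dt ** order) / math.factorial(order)
    return out
```

With `max_order=1` this reduces to linear extrapolation of each cached output between refresh steps.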
class TaylorSeerAttentionCacheHook(ModelHook):
"""
Hook for caching and predicting attention outputs using Taylor series approximations.
Applies to attention modules in diffusion models (e.g., Flux).
Performs full computations during warmup, then alternates between blocks of
predictions and refreshes.
The hook maintains separate states for each inner loop index (e.g., for
classifier-free guidance). Each inner loop has its own list of
`TaylorSeerOutputState` instances, one per output tensor from the module's
forward (typically one).
The `step_counter` increments on every forward call of this module.
We define:
- `inner_index = step_counter % num_inner_loops`
- `true_step = step_counter // num_inner_loops`
Warmup, prediction, and updates are handled per inner loop, but use the
shared `true_step` (outer diffusion step).
"""
class TaylorSeerCacheHook(ModelHook):
_is_stateful = True
def __init__(
self,
module_name: str,
predict_steps: int,
max_order: int,
warmup_steps: int,
taylor_factors_dtype: torch.dtype,
num_inner_loops: int = 1,
state_manager: StateManager,
stop_predicts: Optional[int] = None,
is_skip: bool = False,
):
super().__init__()
if num_inner_loops <= 0:
raise ValueError("num_inner_loops must be >= 1")
self.module_name = module_name
self.predict_steps = predict_steps
self.max_order = max_order
self.warmup_steps = warmup_steps
self.stop_predicts = stop_predicts
self.num_inner_loops = num_inner_loops
self.taylor_factors_dtype = taylor_factors_dtype
self.state_manager = state_manager
self.is_skip = is_skip
self.step_counter: int = -1
self.states: Optional[List[Optional[List[TaylorSeerOutputState]]]] = None
self.num_outputs: Optional[int] = None
self.dummy_outputs = None
def initialize_hook(self, module: torch.nn.Module):
self.step_counter = -1
self.states = None
self.num_outputs = None
return module
def reset_state(self, module: torch.nn.Module) -> None:
"""
Reset state between sampling runs.
"""
self.step_counter = -1
self.states = None
self.num_outputs = None
@staticmethod
def _listify(outputs):
if isinstance(outputs, torch.Tensor):
return [outputs]
return list(outputs)
def _delistify(self, outputs_list):
if self.num_outputs == 1:
return outputs_list[0]
return tuple(outputs_list)
def _ensure_states_initialized(
self,
module: torch.nn.Module,
inner_index: int,
true_step: int,
*args,
**kwargs,
) -> Optional[List[torch.Tensor]]:
"""
Ensure per-inner-loop states exist. If this is the first call for this
inner_index, perform a full forward, initialize states, and return the
outputs. Otherwise, return None.
"""
if self.states is None:
self.states = [None for _ in range(self.num_inner_loops)]
if self.states[inner_index] is not None:
return None
if self.module_name == "proj_out":
logger.debug(
"[FIRST STEP] Initializing states for %s (inner_index=%d, true_step=%d)",
self.module_name,
inner_index,
true_step,
)
# First step for this inner loop: always full compute and initialize.
attention_outputs = self._listify(self.fn_ref.original_forward(*args, **kwargs))
module_dtype = attention_outputs[0].dtype
if self.num_outputs is None:
self.num_outputs = len(attention_outputs)
elif self.num_outputs != len(attention_outputs):
raise ValueError("Output count mismatch across inner loops.")
self.states[inner_index] = [
TaylorSeerOutputState(
self.module_name,
self.taylor_factors_dtype,
module_dtype,
is_skip=self.is_skip,
)
for _ in range(self.num_outputs)
]
for i, features in enumerate(attention_outputs):
self.states[inner_index][i].update(
features=features,
current_step=true_step,
max_order=self.max_order,
predict_steps=self.predict_steps,
)
return attention_outputs
self.dummy_outputs = None
self.current_step = -1
self.state_manager.reset()
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
self.step_counter += 1
inner_index = self.step_counter % self.num_inner_loops
true_step = self.step_counter // self.num_inner_loops
is_warmup_phase = true_step < self.warmup_steps
state: TaylorSeerState = self.state_manager.get_state()
state.current_step += 1
current_step = state.current_step
is_warmup_phase = current_step < self.warmup_steps
should_compute = (
is_warmup_phase
or ((current_step - self.warmup_steps - 1) % self.predict_steps == 0)
or (self.stop_predicts is not None and current_step >= self.stop_predicts)
)
if should_compute:
outputs = self.fn_ref.original_forward(*args, **kwargs)
if not self.is_skip:
state.update((outputs,) if isinstance(outputs, torch.Tensor) else outputs, current_step)
else:
self.dummy_outputs = outputs
return outputs
if self.module_name == "proj_out":
logger.debug(
"[FORWARD] module=%s step_counter=%d inner_index=%d true_step=%d is_warmup=%s",
self.module_name,
self.step_counter,
inner_index,
true_step,
is_warmup_phase,
)
if self.is_skip:
return self.dummy_outputs
# First-time initialization for this inner loop
maybe_outputs = self._ensure_states_initialized(module, inner_index, true_step, *args, **kwargs)
if maybe_outputs is not None:
return self._delistify(maybe_outputs)
assert self.states is not None
states = self.states[inner_index]
assert states is not None and len(states) > 0
# If stop_predicts is set and we are past that step, always run full forward
if self.stop_predicts is not None and true_step >= self.stop_predicts:
attention_outputs = self._listify(self.fn_ref.original_forward(*args, **kwargs))
return self._delistify(attention_outputs)
# Decide between prediction vs refresh
# - Never predict during warmup.
# - Otherwise, predict while we still have remaining_predictions.
should_predict = (not is_warmup_phase) and (states[0].remaining_predictions > 0)
if should_predict:
predicted_outputs = [state.predict(true_step) for state in states]
return self._delistify(predicted_outputs)
# Full compute: warmup or refresh
attention_outputs = self._listify(self.fn_ref.original_forward(*args, **kwargs))
for i, features in enumerate(attention_outputs):
states[i].update(
features=features,
current_step=true_step,
max_order=self.max_order,
predict_steps=self.predict_steps,
)
return self._delistify(attention_outputs)
outputs = state.predict(current_step)
return outputs[0] if len(outputs) == 1 else outputs
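The scheduling condition above is easiest to see on a concrete timeline. Below is a hypothetical helper that mirrors the `should_compute` expression in `new_forward` (the function itself is not part of this diff):

```python
def is_full_compute(step, warmup_steps=3, predict_steps=5, stop_predicts=None):
    if step < warmup_steps:  # warmup: always run the real forward
        return True
    if stop_predicts is not None and step >= stop_predicts:
        return True  # caching disabled from this step onward
    # Same modulo test as the hook: refresh every predict_steps denoising steps.
    return (step - warmup_steps - 1) % predict_steps == 0


# With warmup_steps=3, predict_steps=5: steps 0-2 compute, step 3 predicts,
# step 4 refreshes, steps 5-8 predict, step 9 refreshes, and so on.
schedule = ["compute" if is_full_compute(s) else "predict" for s in range(12)]
```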
def _resolve_patterns(config: TaylorSeerCacheConfig) -> Tuple[List[str], List[str]]:
"""
Resolve effective skip and cache pattern lists from config + templates.
"""
template = _CACHE_TEMPLATES.get(config.architecture or "", {})
default_skip = template.get("skip", [])
default_cache = template.get("cache", [])
skip_patterns = config.skip_identifiers if config.skip_identifiers is not None else default_skip
cache_patterns = config.cache_identifiers if config.cache_identifiers is not None else default_cache
skip_patterns = config.skip_identifiers if config.skip_identifiers is not None else None
cache_patterns = config.cache_identifiers if config.cache_identifiers is not None else None
return skip_patterns or [], cache_patterns or []
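To make the `re.fullmatch` semantics concrete: a pattern must consume the entire dotted module name, so block-level patterns do not implicitly match submodules. The two patterns below are the lite-mode identifiers defined near the top of this file; the module names are hypothetical examples of what `named_modules()` could yield:

```python
import re

lite_cache = [r"^proj_out$"]
lite_skip = [r"^[^.]*block[^.]*\.[^.]+$"]

for name in ["proj_out", "transformer_blocks.0", "transformer_blocks.0.attn"]:
    if any(re.fullmatch(p, name) for p in lite_cache):
        mode = "cache"      # proj_out
    elif any(re.fullmatch(p, name) for p in lite_skip):
        mode = "skip"       # transformer_blocks.0 (the whole block)
    else:
        mode = "untouched"  # transformer_blocks.0.attn (two dots, so no fullmatch)
    print(name, mode)
```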
@@ -496,8 +267,6 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfi
... max_order=1,
... warmup_steps=3,
... taylor_factors_dtype=torch.float32,
... architecture="flux",
... num_inner_loops=2, # e.g. CFG
... )
>>> pipe.transformer.enable_cache(config)
```
@@ -507,67 +276,68 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfi
logger.debug("TaylorSeer skip identifiers: %s", skip_patterns)
logger.debug("TaylorSeer cache identifiers: %s", cache_patterns)
use_patterns = bool(skip_patterns or cache_patterns)
cache_patterns = cache_patterns or _TRANSFORMER_BLOCK_IDENTIFIERS
if config.lite:
logger.info("Using TaylorSeer Lite variant for cache.")
cache_patterns = _PROJ_OUT_IDENTIFIERS
skip_patterns = _BLOCK_IDENTIFIERS
if config.skip_identifiers or config.cache_identifiers:
logger.warning("Lite mode overrides user patterns.")
for name, submodule in module.named_modules():
matches_skip = any(re.fullmatch(pattern, name) for pattern in skip_patterns)
matches_cache = any(re.fullmatch(pattern, name) for pattern in cache_patterns)
if use_patterns:
# If patterns are configured (either skip or cache), only touch modules
# that explicitly match at least one pattern.
if not (matches_skip or matches_cache):
continue
logger.debug(
"Applying TaylorSeer cache to %s (mode=%s)",
name,
"skip" if matches_skip else "cache",
)
_apply_taylorseer_cache_hook(
name=name,
module=submodule,
config=config,
is_skip=matches_skip,
)
else:
# No patterns configured: fall back to "all attention modules".
if isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
logger.debug("Applying TaylorSeer cache to %s (fallback attention mode)", name)
_apply_taylorseer_cache_hook(
name=name,
module=submodule,
config=config,
is_skip=False,
)
if not (matches_skip or matches_cache):
continue
logger.debug(
"Applying TaylorSeer cache to %s (mode=%s)",
name,
"skip" if matches_skip else "cache",
)
state_manager = StateManager(
TaylorSeerState,
init_kwargs={
"taylor_factors_dtype": config.taylor_factors_dtype,
"max_order": config.max_order,
},
)
_apply_taylorseer_cache_hook(
name=name,
module=submodule,
config=config,
is_skip=matches_skip,
state_manager=state_manager,
)
def _apply_taylorseer_cache_hook(
name: str,
module: Attention,
module: nn.Module,
config: TaylorSeerCacheConfig,
is_skip: bool,
state_manager: StateManager,
):
"""
Registers the TaylorSeer hook on the specified attention module.
Registers the TaylorSeer hook on the specified nn.Module.
Args:
name: Name of the module.
module: The attention-like module to be hooked.
module: The nn.Module to be hooked.
config: Cache configuration.
is_skip: Whether this module should operate in "skip" mode.
state_manager: The state manager for managing hook state.
"""
registry = HookRegistry.check_if_exists_or_initialize(module)
hook = TaylorSeerAttentionCacheHook(
hook = TaylorSeerCacheHook(
module_name=name,
predict_steps=config.predict_steps,
max_order=config.max_order,
warmup_steps=config.warmup_steps,
taylor_factors_dtype=config.taylor_factors_dtype,
num_inner_loops=config.num_inner_loops,
stop_predicts=config.stop_predicts,
is_skip=is_skip,
state_manager=state_manager,
)
registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK)
registry.register_hook(hook, _TAYLORSEER_CACHE_HOOK)

View File

@@ -97,7 +97,7 @@ class CacheMixin:
from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
from ..hooks.taylorseer_cache import _TAYLORSEER_ATTENTION_CACHE_HOOK
from ..hooks.taylorseer_cache import _TAYLORSEER_CACHE_HOOK
if self._cache_config is None:
logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
@@ -113,7 +113,7 @@ class CacheMixin:
elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
elif isinstance(self._cache_config, TaylorSeerCacheConfig):
registry.remove_hook(_TAYLORSEER_ATTENTION_CACHE_HOOK, recurse=True)
registry.remove_hook(_TAYLORSEER_CACHE_HOOK, recurse=True)
else:
raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")