mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
refactor to use state manager
@@ -4,36 +4,25 @@ from dataclasses import dataclass
from typing import Optional, List, Dict, Tuple

import torch
import torch.nn as nn

from .hooks import ModelHook
from ..models.attention import Attention
from ..models.attention import AttentionModuleMixin
from ._common import _ATTENTION_CLASSES
from ..hooks import HookRegistry
from .hooks import ModelHook, StateManager, HookRegistry
from ..utils import logging


logger = logging.get_logger(__name__)


_TAYLORSEER_ATTENTION_CACHE_HOOK = "taylorseer_attention_cache"

# Predefined cache templates for optimized architectures
_CACHE_TEMPLATES: Dict[str, Dict[str, List[str]]] = {
"flux": {
"cache": [
r"transformer_blocks\.\d+\.attn",
r"transformer_blocks\.\d+\.ff",
r"transformer_blocks\.\d+\.ff_context",
r"single_transformer_blocks\.\d+\.proj_out",
],
"skip": [
r"single_transformer_blocks\.\d+\.attn",
r"single_transformer_blocks\.\d+\.proj_mlp",
r"single_transformer_blocks\.\d+\.act_mlp",
],
},
}

_TAYLORSEER_CACHE_HOOK = "taylorseer_cache"
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = (
"^blocks.*attn",
"^transformer_blocks.*attn",
"^single_transformer_blocks.*attn",
)
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("^temporal_transformer_blocks.*attn",)
_TRANSFORMER_BLOCK_IDENTIFIERS = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS + _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
_BLOCK_IDENTIFIERS = (
"^[^.]*block[^.]*\\.[^.]+$",
)
_PROJ_OUT_IDENTIFIERS = ("^proj_out$",)


@dataclass
class TaylorSeerCacheConfig:
@@ -43,37 +32,29 @@ class TaylorSeerCacheConfig:
Attributes:
warmup_steps (`int`, defaults to `3`):
Number of *outer* diffusion steps to run with full computation
Number of denoising steps to run with full computation
before enabling caching. During warmup, the Taylor series factors
are still updated, but no predictions are used.

predict_steps (`int`, defaults to `5`):
Number of prediction (cached) steps to take between two full
computations. That is, once a module state is refreshed, it will
be reused for `predict_steps` subsequent outer steps, then a new
be reused for `predict_steps` subsequent denoising steps, then a new
full forward will be computed on the next step.

stop_predicts (`int`, *optional*, defaults to `None`):
Outer diffusion step index at which caching is disabled.
If provided, for `true_step >= stop_predicts` all modules are
Denoising step index at which caching is disabled.
If provided, for `self.current_step >= stop_predicts` all modules are
evaluated normally (no predictions, no state updates).

max_order (`int`, defaults to `1`):
Maximum order of Taylor series expansion to approximate the
features. Higher order gives closer approximation but more compute.

num_inner_loops (`int`, defaults to `1`):
Number of inner loops per outer diffusion step. For example,
with classifier-free guidance (CFG) you typically have 2 inner
loops: unconditional and conditional branches.

taylor_factors_dtype (`torch.dtype`, defaults to `torch.float32`):
taylor_factors_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
Data type for computing Taylor series expansion factors.

architecture (`str`, *optional*, defaults to `None`):
If provided, will look up default `cache` and `skip` regex
patterns in `_CACHE_TEMPLATES[architecture]`. These can be
overridden by `skip_identifiers` and `cache_identifiers`.
Use lower precision to reduce memory usage.
Use higher precision to improve numerical stability.

skip_identifiers (`List[str]`, *optional*, defaults to `None`):
Regex patterns (fullmatch) for module names to be placed in
@@ -85,10 +66,12 @@ class TaylorSeerCacheConfig:
Regex patterns (fullmatch) for module names to be placed in
Taylor-series caching mode.

lite (`bool`, *optional*, defaults to `False`):
Whether to use a TaylorSeer Lite variant that reduces memory usage. This option overrides
any user-provided `skip_identifiers` or `cache_identifiers` patterns.
Notes:
- Patterns are applied with `re.fullmatch` on `module_name`.
- If either `skip_identifiers` or `cache_identifiers` is provided
(or inferred from `architecture`), only modules matching at least
- If either `skip_identifiers` or `cache_identifiers` is provided, only modules matching at least
one of those patterns will be hooked.
- If neither is provided, all attention-like modules will be hooked.
"""
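To make the warmup/refresh schedule described above concrete, the short sketch below (illustrative only, not part of this commit) replays the step rule the refactored hook uses: full computation during warmup, a refresh whenever `(current_step - warmup_steps - 1) % predict_steps == 0`, and no caching once `stop_predicts` is reached. The parameter values are arbitrary examples.

```python
# Illustrative sketch of the compute/predict schedule; the values are made up.
warmup_steps, predict_steps, stop_predicts, total_steps = 3, 5, None, 20

for current_step in range(total_steps):
    is_warmup = current_step < warmup_steps
    should_compute = (
        is_warmup
        or ((current_step - warmup_steps - 1) % predict_steps == 0)
        or (stop_predicts is not None and current_step >= stop_predicts)
    )
    print(current_step, "full forward" if should_compute else "Taylor prediction")
```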
@@ -97,11 +80,10 @@ class TaylorSeerCacheConfig:
predict_steps: int = 5
stop_predicts: Optional[int] = None
max_order: int = 1
num_inner_loops: int = 1
taylor_factors_dtype: torch.dtype = torch.float32
architecture: str | None = None
taylor_factors_dtype: Optional[torch.dtype] = torch.bfloat16
skip_identifiers: Optional[List[str]] = None
cache_identifiers: Optional[List[str]] = None
lite: bool = False

def __repr__(self) -> str:
return (
@@ -110,364 +92,153 @@ class TaylorSeerCacheConfig:
f"predict_steps={self.predict_steps}, "
f"stop_predicts={self.stop_predicts}, "
f"max_order={self.max_order}, "
f"num_inner_loops={self.num_inner_loops}, "
f"taylor_factors_dtype={self.taylor_factors_dtype}, "
f"architecture={self.architecture}, "
f"skip_identifiers={self.skip_identifiers}, "
f"cache_identifiers={self.cache_identifiers})"
f"cache_identifiers={self.cache_identifiers}, "
f"lite={self.lite})"
)

@classmethod
def get_identifiers_template(cls) -> Dict[str, Dict[str, List[str]]]:
return _CACHE_TEMPLATES


class TaylorSeerOutputState:
"""
Manages the state for Taylor series-based prediction of a single attention output.

Tracks Taylor expansion factors, last update step, and remaining prediction steps.
The Taylor expansion uses the (outer) timestep as the independent variable.

This class is designed to handle state for a single inner loop index and a single
output (in cases where the module forward returns multiple tensors).
"""

class TaylorSeerState:
def __init__(
self,
module_name: str,
taylor_factors_dtype: torch.dtype,
module_dtype: torch.dtype,
is_skip: bool = False,
taylor_factors_dtype: Optional[torch.dtype] = torch.bfloat16,
max_order: int = 1,
):
self.module_name = module_name
self.taylor_factors_dtype = taylor_factors_dtype
self.module_dtype = module_dtype
self.is_skip = is_skip
self.max_order = max_order

self.remaining_predictions: int = 0
self.module_dtypes: Tuple[torch.dtype, ...] = ()
self.last_update_step: Optional[int] = None
self.taylor_factors: Dict[int, torch.Tensor] = {}
self.taylor_factors: Dict[int, Dict[int, torch.Tensor]] = {}

# For skip-mode modules
self.dummy_shape: Optional[Tuple[int, ...]] = None
self.device: Optional[torch.device] = None
self.dummy_tensor: Optional[torch.Tensor] = None
self.dummy_tensors: Optional[Tuple[torch.Tensor, ...]] = None

self.current_step = -1

def reset(self) -> None:
self.remaining_predictions = 0
self.last_update_step = None
self.taylor_factors = {}
self.dummy_shape = None
self.device = None
self.dummy_tensor = None
self.dummy_tensors = None
self.current_step = -1

def update(
self,
features: torch.Tensor,
outputs: Tuple[torch.Tensor, ...],
current_step: int,
max_order: int,
predict_steps: int,
) -> None:
"""
Update Taylor factors based on the current features and (outer) timestep.
self.module_dtypes = tuple(output.dtype for output in outputs)
for i in range(len(outputs)):
features = outputs[i].to(self.taylor_factors_dtype)
new_factors: Dict[int, torch.Tensor] = {0: features}
is_first_update = self.last_update_step is None
if not is_first_update:
delta_step = current_step - self.last_update_step
if delta_step == 0:
raise ValueError("Delta step cannot be zero for TaylorSeer update.")

For non-skip modules, finite difference approximations for derivatives are
computed using recursive divided differences.

Args:
features: Attention output features to update with.
current_step: Current outer timestep (true diffusion step).
max_order: Maximum Taylor expansion order.
predict_steps: Number of prediction steps to allow after this update.
"""
if self.is_skip:
# For skip modules we only need shape & device and a dummy tensor.
self.dummy_shape = features.shape
self.device = features.device
# zero is safer than uninitialized values for a "skipped" module
self.dummy_tensor = torch.zeros(
self.dummy_shape,
dtype=self.module_dtype,
device=self.device,
)
self.taylor_factors = {}
self.last_update_step = current_step
self.remaining_predictions = predict_steps
return

features = features.to(self.taylor_factors_dtype)
new_factors: Dict[int, torch.Tensor] = {0: features}

is_first_update = self.last_update_step is None

if not is_first_update:
delta_step = current_step - self.last_update_step
if delta_step == 0:
raise ValueError("Delta step cannot be zero for TaylorSeer update.")

# Recursive divided differences up to max_order
for i in range(max_order):
prev = self.taylor_factors.get(i)
if prev is None:
break
new_factors[i + 1] = (new_factors[i] - prev.to(self.taylor_factors_dtype)) / delta_step

# Keep factors in taylor_factors_dtype
self.taylor_factors = new_factors
# Recursive divided differences up to max_order
for j in range(self.max_order):
prev = self.taylor_factors[i].get(j)
if prev is None:
break
new_factors[j + 1] = (new_factors[j] - prev.to(self.taylor_factors_dtype)) / delta_step
self.taylor_factors[i] = new_factors
self.last_update_step = current_step
self.remaining_predictions = predict_steps

if self.module_name == "proj_out":
logger.debug(
"[UPDATE] module=%s remaining_predictions=%d current_step=%d is_first_update=%s",
self.module_name,
self.remaining_predictions,
current_step,
is_first_update,
)

def predict(self, current_step: int) -> torch.Tensor:
"""
Predict features using the Taylor series at the given (outer) timestep.

Args:
current_step: Current outer timestep for prediction.

Returns:
Predicted features in the module's dtype.
"""
if self.is_skip:
if self.dummy_tensor is None:
raise ValueError("Cannot predict for skip module without prior update.")
self.remaining_predictions -= 1
return self.dummy_tensor

if self.last_update_step is None:
raise ValueError("Cannot predict without prior initialization/update.")

step_offset = current_step - self.last_update_step

output: torch.Tensor
if not self.taylor_factors:
raise ValueError("Taylor factors empty during prediction.")

# Accumulate Taylor series: f(t0 + Δt) ≈ Σ f^{(n)}(t0) * (Δt^n / n!)
output = torch.zeros_like(self.taylor_factors[0])
for order, factor in self.taylor_factors.items():
# Note: order starts at 0
coeff = (step_offset**order) / math.factorial(order)
output = output + factor * coeff
outputs = []
for i in range(len(self.module_dtypes)):
taylor_factors = self.taylor_factors[i]
# Accumulate Taylor series: f(t0 + Δt) ≈ Σ f^{(n)}(t0) * (Δt^n / n!)
output = torch.zeros_like(taylor_factors[0])
for order, factor in taylor_factors.items():
# Note: order starts at 0
coeff = (step_offset**order) / math.factorial(order)
output = output + factor * coeff
outputs.append(output.to(self.module_dtypes[i]))

self.remaining_predictions -= 1
out = output.to(self.module_dtype)

if self.module_name == "proj_out":
logger.debug(
"[PREDICT] module=%s remaining_predictions=%d current_step=%d last_update_step=%s",
self.module_name,
self.remaining_predictions,
current_step,
self.last_update_step,
)

return out
return outputs

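The update/predict pair above is easier to follow on a toy example. The sketch below (standalone and illustrative, not the diffusers implementation; the helper names are invented) builds the divided-difference factors and then evaluates the truncated Taylor series f(t0 + Δt) ≈ Σ f^(n)(t0) · Δt^n / n!, which is the extrapolation both the old and the refactored state classes perform per cached tensor.

```python
import math

import torch


def update_factors(old_factors, features, delta_step, max_order=1):
    # factors[0] is the latest features; factors[k + 1] is the k-th divided difference.
    new_factors = {0: features}
    for k in range(max_order):
        if k not in old_factors:
            break
        new_factors[k + 1] = (new_factors[k] - old_factors[k]) / delta_step
    return new_factors


def predict(factors, step_offset):
    # Truncated Taylor expansion around the last full-compute step.
    out = torch.zeros_like(factors[0])
    for order, factor in factors.items():
        out = out + factor * (step_offset**order) / math.factorial(order)
    return out


f_t0 = torch.tensor([1.0, 2.0, 3.0])  # module output at step 0
f_t1 = torch.tensor([1.5, 2.5, 3.5])  # module output at step 1
factors = update_factors({}, f_t0, delta_step=1)       # zeroth-order term only
factors = update_factors(factors, f_t1, delta_step=1)  # adds the first-order term
print(predict(factors, step_offset=1))                 # linear extrapolation toward step 2
```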
class TaylorSeerAttentionCacheHook(ModelHook):
"""
Hook for caching and predicting attention outputs using Taylor series approximations.

Applies to attention modules in diffusion models (e.g., Flux).
Performs full computations during warmup, then alternates between blocks of
predictions and refreshes.

The hook maintains separate states for each inner loop index (e.g., for
classifier-free guidance). Each inner loop has its own list of
`TaylorSeerOutputState` instances, one per output tensor from the module's
forward (typically one).

The `step_counter` increments on every forward call of this module.
We define:
- `inner_index = step_counter % num_inner_loops`
- `true_step = step_counter // num_inner_loops`

Warmup, prediction, and updates are handled per inner loop, but use the
shared `true_step` (outer diffusion step).
"""

class TaylorSeerCacheHook(ModelHook):
_is_stateful = True

def __init__(
self,
module_name: str,
predict_steps: int,
max_order: int,
warmup_steps: int,
taylor_factors_dtype: torch.dtype,
num_inner_loops: int = 1,
state_manager: StateManager,
stop_predicts: Optional[int] = None,
is_skip: bool = False,
):
super().__init__()
if num_inner_loops <= 0:
raise ValueError("num_inner_loops must be >= 1")

self.module_name = module_name
self.predict_steps = predict_steps
self.max_order = max_order
self.warmup_steps = warmup_steps
self.stop_predicts = stop_predicts
self.num_inner_loops = num_inner_loops
self.taylor_factors_dtype = taylor_factors_dtype
self.state_manager = state_manager
self.is_skip = is_skip

self.step_counter: int = -1
self.states: Optional[List[Optional[List[TaylorSeerOutputState]]]] = None
self.num_outputs: Optional[int] = None
self.dummy_outputs = None

def initialize_hook(self, module: torch.nn.Module):
self.step_counter = -1
self.states = None
self.num_outputs = None
return module

def reset_state(self, module: torch.nn.Module) -> None:
"""
Reset state between sampling runs.
"""
self.step_counter = -1
self.states = None
self.num_outputs = None

@staticmethod
def _listify(outputs):
if isinstance(outputs, torch.Tensor):
return [outputs]
return list(outputs)

def _delistify(self, outputs_list):
if self.num_outputs == 1:
return outputs_list[0]
return tuple(outputs_list)

def _ensure_states_initialized(
self,
module: torch.nn.Module,
inner_index: int,
true_step: int,
*args,
**kwargs,
) -> Optional[List[torch.Tensor]]:
"""
Ensure per-inner-loop states exist. If this is the first call for this
inner_index, perform a full forward, initialize states, and return the
outputs. Otherwise, return None.
"""
if self.states is None:
self.states = [None for _ in range(self.num_inner_loops)]

if self.states[inner_index] is not None:
return None

if self.module_name == "proj_out":
logger.debug(
"[FIRST STEP] Initializing states for %s (inner_index=%d, true_step=%d)",
self.module_name,
inner_index,
true_step,
)

# First step for this inner loop: always full compute and initialize.
attention_outputs = self._listify(self.fn_ref.original_forward(*args, **kwargs))
module_dtype = attention_outputs[0].dtype

if self.num_outputs is None:
self.num_outputs = len(attention_outputs)
elif self.num_outputs != len(attention_outputs):
raise ValueError("Output count mismatch across inner loops.")

self.states[inner_index] = [
TaylorSeerOutputState(
self.module_name,
self.taylor_factors_dtype,
module_dtype,
is_skip=self.is_skip,
)
for _ in range(self.num_outputs)
]

for i, features in enumerate(attention_outputs):
self.states[inner_index][i].update(
features=features,
current_step=true_step,
max_order=self.max_order,
predict_steps=self.predict_steps,
)

return attention_outputs
self.dummy_outputs = None
self.current_step = -1
self.state_manager.reset()

def new_forward(self, module: torch.nn.Module, *args, **kwargs):
self.step_counter += 1
inner_index = self.step_counter % self.num_inner_loops
true_step = self.step_counter // self.num_inner_loops
is_warmup_phase = true_step < self.warmup_steps
state: TaylorSeerState = self.state_manager.get_state()
state.current_step += 1
current_step = state.current_step
is_warmup_phase = current_step < self.warmup_steps
should_compute = (
is_warmup_phase
or ((current_step - self.warmup_steps - 1) % self.predict_steps == 0)
or (self.stop_predicts is not None and current_step >= self.stop_predicts)
)
if should_compute:
outputs = self.fn_ref.original_forward(*args, **kwargs)
if not self.is_skip:
state.update((outputs,) if isinstance(outputs, torch.Tensor) else outputs, current_step)
else:
self.dummy_outputs = outputs
return outputs

if self.module_name == "proj_out":
logger.debug(
"[FORWARD] module=%s step_counter=%d inner_index=%d true_step=%d is_warmup=%s",
self.module_name,
self.step_counter,
inner_index,
true_step,
is_warmup_phase,
)
if self.is_skip:
return self.dummy_outputs

# First-time initialization for this inner loop
maybe_outputs = self._ensure_states_initialized(module, inner_index, true_step, *args, **kwargs)
if maybe_outputs is not None:
return self._delistify(maybe_outputs)

assert self.states is not None
states = self.states[inner_index]
assert states is not None and len(states) > 0

# If stop_predicts is set and we are past that step, always run full forward
if self.stop_predicts is not None and true_step >= self.stop_predicts:
attention_outputs = self._listify(self.fn_ref.original_forward(*args, **kwargs))
return self._delistify(attention_outputs)

# Decide between prediction vs refresh
# - Never predict during warmup.
# - Otherwise, predict while we still have remaining_predictions.
should_predict = (not is_warmup_phase) and (states[0].remaining_predictions > 0)

if should_predict:
predicted_outputs = [state.predict(true_step) for state in states]
return self._delistify(predicted_outputs)

# Full compute: warmup or refresh
attention_outputs = self._listify(self.fn_ref.original_forward(*args, **kwargs))
for i, features in enumerate(attention_outputs):
states[i].update(
features=features,
current_step=true_step,
max_order=self.max_order,
predict_steps=self.predict_steps,
)
return self._delistify(attention_outputs)
outputs = state.predict(current_step)
return outputs[0] if len(outputs) == 1 else outputs

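For readers who have not met the `StateManager` abstraction this refactor moves to: the hook no longer owns its bookkeeping; it asks the manager for a per-module state object on every forward and resets it between runs. The class below is a deliberately simplified, hypothetical stand-in (not the diffusers implementation) that supports just the calls the hook and the application code rely on.

```python
from typing import Any, Dict, Optional, Type


class SimpleStateManager:
    """Hypothetical, minimal stand-in for the real StateManager used above."""

    def __init__(self, state_cls: Type, init_kwargs: Optional[Dict[str, Any]] = None):
        self._state_cls = state_cls
        self._init_kwargs = init_kwargs or {}
        self._state = None

    def get_state(self):
        # Lazily create a single state object and return it on every call,
        # mirroring `state = self.state_manager.get_state()` in new_forward.
        if self._state is None:
            self._state = self._state_cls(**self._init_kwargs)
        return self._state

    def reset(self):
        # Drop cached Taylor factors so the next sampling run starts clean,
        # mirroring `self.state_manager.reset()` in reset_state.
        if self._state is not None:
            self._state.reset()
```

The real manager is constructed per hooked module, e.g. `StateManager(TaylorSeerState, init_kwargs={"taylor_factors_dtype": ..., "max_order": ...})` as in the application code further down; the stand-in only illustrates the shape of that contract.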
def _resolve_patterns(config: TaylorSeerCacheConfig) -> Tuple[List[str], List[str]]:
"""
Resolve effective skip and cache pattern lists from config + templates.
"""
template = _CACHE_TEMPLATES.get(config.architecture or "", {})
default_skip = template.get("skip", [])
default_cache = template.get("cache", [])

skip_patterns = config.skip_identifiers if config.skip_identifiers is not None else default_skip
cache_patterns = config.cache_identifiers if config.cache_identifiers is not None else default_cache
skip_patterns = config.skip_identifiers if config.skip_identifiers is not None else None
cache_patterns = config.cache_identifiers if config.cache_identifiers is not None else None

return skip_patterns or [], cache_patterns or []
@@ -496,8 +267,6 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfi
... max_order=1,
... warmup_steps=3,
... taylor_factors_dtype=torch.float32,
... architecture="flux",
... num_inner_loops=2, # e.g. CFG
... )
>>> pipe.transformer.enable_cache(config)
```
@@ -507,67 +276,68 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfi
logger.debug("TaylorSeer skip identifiers: %s", skip_patterns)
logger.debug("TaylorSeer cache identifiers: %s", cache_patterns)

use_patterns = bool(skip_patterns or cache_patterns)
cache_patterns = cache_patterns or _TRANSFORMER_BLOCK_IDENTIFIERS

if config.lite:
logger.info("Using TaylorSeer Lite variant for cache.")
cache_patterns = _PROJ_OUT_IDENTIFIERS
skip_patterns = _BLOCK_IDENTIFIERS
if config.skip_identifiers or config.cache_identifiers:
logger.warning("Lite mode overrides user patterns.")

for name, submodule in module.named_modules():
matches_skip = any(re.fullmatch(pattern, name) for pattern in skip_patterns)
matches_cache = any(re.fullmatch(pattern, name) for pattern in cache_patterns)

if use_patterns:
# If patterns are configured (either skip or cache), only touch modules
# that explicitly match at least one pattern.
if not (matches_skip or matches_cache):
continue

logger.debug(
"Applying TaylorSeer cache to %s (mode=%s)",
name,
"skip" if matches_skip else "cache",
)
_apply_taylorseer_cache_hook(
name=name,
module=submodule,
config=config,
is_skip=matches_skip,
)
else:
# No patterns configured: fall back to "all attention modules".
if isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
logger.debug("Applying TaylorSeer cache to %s (fallback attention mode)", name)
_apply_taylorseer_cache_hook(
name=name,
module=submodule,
config=config,
is_skip=False,
)
if not (matches_skip or matches_cache):
continue
logger.debug(
"Applying TaylorSeer cache to %s (mode=%s)",
name,
"skip" if matches_skip else "cache",
)
state_manager = StateManager(
TaylorSeerState,
init_kwargs={
"taylor_factors_dtype": config.taylor_factors_dtype,
"max_order": config.max_order,
},
)
_apply_taylorseer_cache_hook(
name=name,
module=submodule,
config=config,
is_skip=matches_skip,
state_manager=state_manager,
)

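The `re.fullmatch` loop above is the entire selection mechanism: a module gets hooked only if its dotted name fully matches a cache or a skip pattern. A small illustration using the Lite-mode identifiers defined at the top of the file (the module names themselves are made up):

```python
import re

cache_patterns = ["^proj_out$"]                # _PROJ_OUT_IDENTIFIERS
skip_patterns = ["^[^.]*block[^.]*\\.[^.]+$"]  # _BLOCK_IDENTIFIERS

for name in ["proj_out", "transformer_blocks.7", "time_text_embed.timestep_embedder.linear_1"]:
    matches_cache = any(re.fullmatch(p, name) for p in cache_patterns)
    matches_skip = any(re.fullmatch(p, name) for p in skip_patterns)
    print(name, "cache" if matches_cache else "skip" if matches_skip else "not hooked")
```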
def _apply_taylorseer_cache_hook(
name: str,
module: Attention,
module: nn.Module,
config: TaylorSeerCacheConfig,
is_skip: bool,
state_manager: StateManager,
):
"""
Registers the TaylorSeer hook on the specified attention module.
Registers the TaylorSeer hook on the specified nn.Module.

Args:
name: Name of the module.
module: The attention-like module to be hooked.
module: The nn.Module to be hooked.
config: Cache configuration.
is_skip: Whether this module should operate in "skip" mode.
state_manager: The state manager for managing hook state.
"""
registry = HookRegistry.check_if_exists_or_initialize(module)

hook = TaylorSeerAttentionCacheHook(
hook = TaylorSeerCacheHook(
module_name=name,
predict_steps=config.predict_steps,
max_order=config.max_order,
warmup_steps=config.warmup_steps,
taylor_factors_dtype=config.taylor_factors_dtype,
num_inner_loops=config.num_inner_loops,
stop_predicts=config.stop_predicts,
is_skip=is_skip,
state_manager=state_manager,
)

registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK)
registry.register_hook(hook, _TAYLORSEER_CACHE_HOOK)
@@ -97,7 +97,7 @@ class CacheMixin:
from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
from ..hooks.taylorseer_cache import _TAYLORSEER_ATTENTION_CACHE_HOOK
from ..hooks.taylorseer_cache import _TAYLORSEER_CACHE_HOOK

if self._cache_config is None:
logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
@@ -113,7 +113,7 @@ class CacheMixin:
elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
elif isinstance(self._cache_config, TaylorSeerCacheConfig):
registry.remove_hook(_TAYLORSEER_ATTENTION_CACHE_HOOK, recurse=True)
registry.remove_hook(_TAYLORSEER_CACHE_HOOK, recurse=True)
else:
raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")
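For context, the `CacheMixin` changes above only swap the hook key that gets removed when caching is turned off. A typical round trip looks roughly like the following; it assumes a pipeline already loaded as `pipe` and that `TaylorSeerCacheConfig` is exported alongside the other cache configs (the exact import path may differ).

```python
from diffusers.hooks import TaylorSeerCacheConfig  # import path assumed, may differ

config = TaylorSeerCacheConfig(predict_steps=5, max_order=1, warmup_steps=3)
pipe.transformer.enable_cache(config)   # registers the TaylorSeer hooks
# ... run inference ...
pipe.transformer.disable_cache()        # removes the hooks registered under _TAYLORSEER_CACHE_HOOK
```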