
refactor, add docs

This commit is contained in:
toilaluan
2025-11-14 07:00:12 +00:00
parent 0602044da7
commit 1099e493e6
4 changed files with 182 additions and 72 deletions

View File

@@ -169,10 +169,12 @@ else:
"LayerSkipConfig",
"PyramidAttentionBroadcastConfig",
"SmoothedEnergyGuidanceConfig",
"TaylorSeerCacheConfig",
"apply_faster_cache",
"apply_first_block_cache",
"apply_layer_skip",
"apply_pyramid_attention_broadcast",
"apply_taylorseer_cache",
]
)
_import_structure["models"].extend(
@@ -883,10 +885,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LayerSkipConfig,
PyramidAttentionBroadcastConfig,
SmoothedEnergyGuidanceConfig,
TaylorSeerCacheConfig,
apply_faster_cache,
apply_first_block_cache,
apply_layer_skip,
apply_pyramid_attention_broadcast,
apply_taylorseer_cache,
)
from .models import (
AllegroTransformer3DModel,
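
With these exports added to the lazy import structure and the `TYPE_CHECKING` branch, the new cache API should be importable from the package root. A minimal smoke test, assuming a diffusers build that includes this commit:

```python
# Assumes a diffusers build that includes this commit.
import torch
from diffusers import TaylorSeerCacheConfig, apply_taylorseer_cache

config = TaylorSeerCacheConfig(warmup_steps=3, predict_steps=5, max_order=1, taylor_factors_dtype=torch.float32)
print(config)  # __repr__ shows the configured values
```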

View File

@@ -25,3 +25,4 @@ if is_torch_available():
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig
from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache

View File

@@ -1,9 +1,6 @@
import torch
from dataclasses import dataclass
from typing import Callable, Optional, List, Dict
from .hooks import ModelHook
import math
from ..models.attention import Attention
@@ -13,118 +10,219 @@ from ._common import (
)
from ..hooks import HookRegistry
from ..utils import logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
_TAYLORSEER_ATTENTION_CACHE_HOOK = "taylorseer_attention_cache"
@dataclass
class TaylorSeerCacheConfig:
    """
    Configuration for TaylorSeer cache.
    See: https://huggingface.co/papers/2503.06923

    Attributes:
        warmup_steps (int, defaults to 3): Number of warmup steps without caching.
        predict_steps (int, defaults to 5): Number of prediction (cache) steps between non-cached steps.
        max_order (int, defaults to 1): Maximum order of Taylor series expansion to approximate the features.
        taylor_factors_dtype (torch.dtype, defaults to torch.float32): Data type for Taylor series expansion factors.
    """

    warmup_steps: int = 3
    predict_steps: int = 5
    max_order: int = 1
    taylor_factors_dtype: torch.dtype = torch.float32

    def __repr__(self) -> str:
        return f"TaylorSeerCacheConfig(warmup_steps={self.warmup_steps}, predict_steps={self.predict_steps}, max_order={self.max_order}, taylor_factors_dtype={self.taylor_factors_dtype})"
class TaylorSeerOutputState:
    """
    Manages the state for Taylor series-based prediction of a single attention output.
    Tracks Taylor expansion factors, last update step, and remaining prediction steps.
    The Taylor expansion uses the timestep as the independent variable for approximation.
    """

    def __init__(self, module_name: str, taylor_factors_dtype: torch.dtype, module_dtype: torch.dtype):
        self.module_name = module_name
        self.remaining_predictions: int = 0
        self.last_update_step: Optional[int] = None
        self.taylor_factors: Dict[int, torch.Tensor] = {}
        self.taylor_factors_dtype = taylor_factors_dtype
        self.module_dtype = module_dtype

    def reset(self):
        self.remaining_predictions = 0
        self.last_update_step = None
        self.taylor_factors = {}

    def update(self, features: torch.Tensor, current_step: int, max_order: int, predict_steps: int, is_first_update: bool):
        """
        Updates the Taylor factors based on the current features and timestep.
        Computes finite difference approximations for derivatives using recursive divided differences.

        Args:
            features (torch.Tensor): The attention output features to update with.
            current_step (int): The current timestep or step number from the diffusion model.
            max_order (int): Maximum order of the Taylor expansion.
            predict_steps (int): Number of prediction steps to set after update.
            is_first_update (bool): Whether this is the initial update (skips difference computation).
        """
        features = features.to(self.taylor_factors_dtype)
        new_factors = {0: features}
        if not is_first_update:
            if self.last_update_step is None:
                raise ValueError("Cannot update without prior initialization.")
            delta_step = current_step - self.last_update_step
            if delta_step == 0:
                raise ValueError("Delta step cannot be zero for updates.")
            for i in range(max_order):
                if i in self.taylor_factors:
                    # Finite difference: (current - previous) / delta for forward approximation
                    new_factors[i + 1] = (new_factors[i] - self.taylor_factors[i].to(self.taylor_factors_dtype)) / delta_step
        # Taylor factors are kept in taylor_factors_dtype
        self.taylor_factors = new_factors
        self.last_update_step = current_step
        self.remaining_predictions = predict_steps

    def predict(self, current_step: int) -> torch.Tensor:
        """
        Predicts the features using the Taylor series expansion at the given timestep.

        Args:
            current_step (int): The current timestep for prediction.

        Returns:
            torch.Tensor: The predicted features in the module's dtype.
        """
        if self.last_update_step is None:
            raise ValueError("Cannot predict without prior update.")
        step_offset = current_step - self.last_update_step
        device = self.taylor_factors[0].device
        output = torch.zeros_like(self.taylor_factors[0], device=device, dtype=self.taylor_factors_dtype)
        for order in range(len(self.taylor_factors)):
            output += self.taylor_factors[order] * (step_offset ** order) / math.factorial(order)
        self.remaining_predictions -= 1
        # The output is converted back to the module's dtype
        return output.to(self.module_dtype)
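
To make the update/predict arithmetic above concrete, here is a minimal standalone sketch of the first-order case. It mirrors the finite-difference update and Taylor extrapolation rather than calling the class, and the feature values and step numbers are invented for illustration:

```python
import math
import torch

# Toy first-order walk-through of the update()/predict() arithmetic (invented values).
f_prev = torch.tensor([1.0, 2.0])   # full compute at step 0
f_curr = torch.tensor([1.6, 2.9])   # full compute at step 6
prev_step, curr_step = 0, 6

# update(): order-0 factor is the latest feature, order-1 factor is a finite difference.
factors = {0: f_curr, 1: (f_curr - f_prev) / (curr_step - prev_step)}

# predict(): evaluate the Taylor expansion around curr_step at a later step.
step_offset = 8 - curr_step  # predicting step 8
prediction = sum(factors[i] * step_offset**i / math.factorial(i) for i in range(len(factors)))
print(prediction)  # tensor([1.8000, 3.2000]) -- linear extrapolation of the cached feature
```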
class TaylorSeerAttentionCacheHook(ModelHook):
    """
    Hook for caching and predicting attention outputs using Taylor series approximations.
    Applies to attention modules in diffusion models (e.g., Flux).
    Performs full computations during warmup, then alternates between predictions and refreshes.
    """

    _is_stateful = True

    def __init__(
        self,
        module_name: str,
        predict_steps: int,
        max_order: int,
        warmup_steps: int,
        taylor_factors_dtype: torch.dtype,
        module_dtype: torch.dtype = None,
    ):
        super().__init__()
        self.module_name = module_name
        self.predict_steps = predict_steps
        self.max_order = max_order
        self.warmup_steps = warmup_steps
        self.step_counter = -1
        self.states: Optional[List[TaylorSeerOutputState]] = None
        self.num_outputs: Optional[int] = None
        self.taylor_factors_dtype = taylor_factors_dtype
        self.module_dtype = module_dtype

    def initialize_hook(self, module: torch.nn.Module):
        self.step_counter = -1
        self.states = None
        self.num_outputs = None
        self.module_dtype = None
        return module

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        self.step_counter += 1
        is_warmup_phase = self.step_counter < self.warmup_steps

        if self.states is None:
            # First step: always full compute and initialize
            attention_outputs = self.fn_ref.original_forward(*args, **kwargs)
            if isinstance(attention_outputs, torch.Tensor):
                attention_outputs = [attention_outputs]
            else:
                attention_outputs = list(attention_outputs)
            module_dtype = attention_outputs[0].dtype
            self.num_outputs = len(attention_outputs)
            self.states = [
                TaylorSeerOutputState(self.module_name, self.taylor_factors_dtype, module_dtype)
                for _ in range(self.num_outputs)
            ]
            for i, features in enumerate(attention_outputs):
                self.states[i].update(features, self.step_counter, self.max_order, self.predict_steps, is_first_update=True)
            return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs)

        should_predict = self.states[0].remaining_predictions > 0
        if is_warmup_phase or not should_predict:
            # Full compute during warmup or when refresh needed
            attention_outputs = self.fn_ref.original_forward(*args, **kwargs)
            if isinstance(attention_outputs, torch.Tensor):
                attention_outputs = [attention_outputs]
            else:
                attention_outputs = list(attention_outputs)
            is_first_update = self.step_counter == 0  # Only True for the very first step
            for i, features in enumerate(attention_outputs):
                self.states[i].update(features, self.step_counter, self.max_order, self.predict_steps, is_first_update)
            return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs)
        else:
            # Predict using Taylor series
            predicted_outputs = [state.predict(self.step_counter) for state in self.states]
            return predicted_outputs[0] if self.num_outputs == 1 else tuple(predicted_outputs)

    def reset_state(self, module: torch.nn.Module) -> None:
        if self.states is not None:
            for state in self.states:
                state.reset()
        return module
def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfig):
    """
    Applies the TaylorSeer cache to the given module.

    Args:
        module (torch.nn.Module): The model to apply the hook to.
        config (TaylorSeerCacheConfig): Configuration for the cache.

    Example:
    ```python
    >>> import torch
    >>> from diffusers import FluxPipeline, TaylorSeerCacheConfig, apply_taylorseer_cache

    >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
    >>> pipe.to("cuda")

    >>> config = TaylorSeerCacheConfig(predict_steps=5, max_order=1, warmup_steps=3, taylor_factors_dtype=torch.float32)
    >>> apply_taylorseer_cache(pipe.transformer, config)
    ```
    """
    for name, submodule in module.named_modules():
        if isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
            logger.debug(f"Applying TaylorSeer cache to {name}")
            _apply_taylorseer_cache_hook(name, submodule, config)


def _apply_taylorseer_cache_hook(name: str, module: Attention, config: TaylorSeerCacheConfig):
    """
    Registers the TaylorSeer hook on the specified attention module.

    Args:
        name (str): Name of the module.
        module (Attention): The attention module.
        config (TaylorSeerCacheConfig): Configuration for the cache.
    """
    registry = HookRegistry.check_if_exists_or_initialize(module)
    hook = TaylorSeerAttentionCacheHook(
        name,
        config.predict_steps,
        config.max_order,
        config.warmup_steps,
        config.taylor_factors_dtype,
    )
    registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK)
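
For intuition about when the hook above recomputes versus predicts, the following is a schematic trace of the scheduling decision only, not the hook itself. With the defaults `warmup_steps=3` and `predict_steps=5`, full attention forwards happen at steps 0-2, 8, and 14 of a 16-step run, and every other step reuses the cached Taylor factors:

```python
# Schematic trace of the compute/predict cadence in TaylorSeerAttentionCacheHook.new_forward,
# using the default warmup_steps=3 and predict_steps=5 (this is not the hook itself).
warmup_steps, predict_steps = 3, 5
remaining_predictions = 0
for step in range(16):  # pretend the attention module is called once per denoising step
    if step < warmup_steps or remaining_predictions == 0:
        remaining_predictions = predict_steps
        print(step, "full compute")   # real attention forward, Taylor factors refreshed
    else:
        remaining_predictions -= 1
        print(step, "predicted")      # output extrapolated from cached Taylor factors
```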

View File

@@ -67,9 +67,11 @@ class CacheMixin:
FasterCacheConfig,
FirstBlockCacheConfig,
PyramidAttentionBroadcastConfig,
TaylorSeerCacheConfig,
apply_faster_cache,
apply_first_block_cache,
apply_pyramid_attention_broadcast,
apply_taylorseer_cache,
)
if self.is_cache_enabled:
@@ -83,16 +85,19 @@ class CacheMixin:
apply_first_block_cache(self, config)
elif isinstance(config, PyramidAttentionBroadcastConfig):
apply_pyramid_attention_broadcast(self, config)
elif isinstance(config, TaylorSeerCacheConfig):
apply_taylorseer_cache(self, config)
else:
raise ValueError(f"Cache config {type(config)} is not supported.")
self._cache_config = config
def disable_cache(self) -> None:
from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig, TaylorSeerCacheConfig
from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
from ..hooks.taylorseer_cache import _TAYLORSEER_ATTENTION_CACHE_HOOK
if self._cache_config is None:
logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
@@ -107,6 +112,8 @@ class CacheMixin:
registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True)
elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
elif isinstance(self._cache_config, TaylorSeerCacheConfig):
registry.remove_hook(_TAYLORSEER_ATTENTION_CACHE_HOOK, recurse=True)
else:
raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")