Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
refactor, add docs
@@ -169,10 +169,12 @@ else:
"LayerSkipConfig",
"PyramidAttentionBroadcastConfig",
"SmoothedEnergyGuidanceConfig",
"TaylorSeerCacheConfig",
"apply_faster_cache",
"apply_first_block_cache",
"apply_layer_skip",
"apply_pyramid_attention_broadcast",
"apply_taylorseer_cache",
]
)
_import_structure["models"].extend(
@@ -883,10 +885,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LayerSkipConfig,
PyramidAttentionBroadcastConfig,
SmoothedEnergyGuidanceConfig,
TaylorSeerCacheConfig,
apply_faster_cache,
apply_first_block_cache,
apply_layer_skip,
apply_pyramid_attention_broadcast,
apply_taylorseer_cache,
)
from .models import (
AllegroTransformer3DModel,

@@ -25,3 +25,4 @@ if is_torch_available():
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig
from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache
@@ -1,9 +1,6 @@
# Experimental hook for TaylorSeer cache
# Supports Flux only for now

import torch
from dataclasses import dataclass
from typing import Callable
from typing import Callable, Optional, List, Dict
from .hooks import ModelHook
import math
from ..models.attention import Attention
@@ -13,118 +10,219 @@ from ._common import (
)
from ..hooks import HookRegistry
from ..utils import logging

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
_TAYLORSEER_ATTENTION_CACHE_HOOK = "taylorseer_attention_cache"

@dataclass
class TaylorSeerCacheConfig:
warmup_steps: int = 3  # full compute some first steps
fresh_threshold: int = 5  # interleave cache and compute: `fresh_threshold` steps are cached, then 1 full compute step is performed
max_order: int = 1  # order of Taylor series expansion
current_timestep_callback: Callable[[], int] = None
"""
Configuration for TaylorSeer cache.
See: https://huggingface.co/papers/2503.06923

class TaylorSeerState:
def __init__(self):
self.predict_counter: int = 0
self.last_step: int = 1000
self.taylor_factors: dict[int, torch.Tensor] = {}
Attributes:
warmup_steps (int, defaults to 3): Number of warmup steps without caching.
predict_steps (int, defaults to 5): Number of prediction (cache) steps between non-cached steps.
max_order (int, defaults to 1): Maximum order of Taylor series expansion to approximate the features.
taylor_factors_dtype (torch.dtype, defaults to torch.float32): Data type for Taylor series expansion factors.
"""
warmup_steps: int = 3
predict_steps: int = 5
max_order: int = 1
taylor_factors_dtype: torch.dtype = torch.float32

def __repr__(self) -> str:
return f"TaylorSeerCacheConfig(warmup_steps={self.warmup_steps}, predict_steps={self.predict_steps}, max_order={self.max_order}, taylor_factors_dtype={self.taylor_factors_dtype})"
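For orientation, the fields above imply a simple per-module schedule: the first `warmup_steps` calls are always fully computed, and afterwards every full (refresh) call is followed by `predict_steps` predicted calls. Below is a minimal sketch of that schedule, assuming the counter behaviour implemented in the hook further down; the helper name `plan_schedule` is made up for illustration and is not part of this commit.

```python
# Illustrative only: simulate which steps are fully computed vs. predicted for
# TaylorSeerCacheConfig(warmup_steps=3, predict_steps=5). Not library code.
def plan_schedule(num_steps: int, warmup_steps: int = 3, predict_steps: int = 5) -> list:
    schedule, remaining_predictions = [], 0
    for step in range(num_steps):
        if step < warmup_steps or remaining_predictions == 0:
            schedule.append("compute")  # full forward pass, Taylor factors refreshed
            remaining_predictions = predict_steps
        else:
            schedule.append("predict")  # reuse cached Taylor factors
            remaining_predictions -= 1
    return schedule


print(plan_schedule(12))
# ['compute', 'compute', 'compute', 'predict', 'predict', 'predict', 'predict',
#  'predict', 'compute', 'predict', 'predict', 'predict']
```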

class TaylorSeerOutputState:
"""
Manages the state for Taylor series-based prediction of a single attention output.
Tracks Taylor expansion factors, last update step, and remaining prediction steps.
The Taylor expansion uses the timestep as the independent variable for approximation.
"""

def __init__(self, module_name: str, taylor_factors_dtype: torch.dtype, module_dtype: torch.dtype):
self.module_name = module_name
self.remaining_predictions: int = 0
self.last_update_step: Optional[int] = None
self.taylor_factors: Dict[int, torch.Tensor] = {}
self.taylor_factors_dtype = taylor_factors_dtype
self.module_dtype = module_dtype

def reset(self):
self.predict_counter = 0
self.last_step = 1000
self.remaining_predictions = 0
self.last_update_step = None
self.taylor_factors = {}

def update(self, features: torch.Tensor, current_step: int, max_order: int, refresh_threshold: int):
logger.debug("="*10)
N = self.last_step - current_step
logger.debug(f"update: N: {N}, current_step: {current_step}, last_step: {self.last_step}")
# initialize the first order taylor factors
new_taylor_factors = {0: features}
for i in range(max_order):
if (self.taylor_factors.get(i) is not None) and current_step > 1:
new_taylor_factors[i+1] = (self.taylor_factors[i] - new_taylor_factors[i]) / N
else:
break
self.taylor_factors = new_taylor_factors
self.last_step = current_step
self.predict_counter = refresh_threshold
logger.debug(f"last_step: {self.last_step}")
logger.debug(f"predict_counter: {self.predict_counter}")
logger.debug("="*10)
def update(self, features: torch.Tensor, current_step: int, max_order: int, predict_steps: int, is_first_update: bool):
"""
Updates the Taylor factors based on the current features and timestep.
Computes finite difference approximations for derivatives using recursive divided differences.

def predict(self, current_step: int):
k = current_step - self.last_step
Args:
features (torch.Tensor): The attention output features to update with.
current_step (int): The current timestep or step number from the diffusion model.
max_order (int): Maximum order of the Taylor expansion.
predict_steps (int): Number of prediction steps to set after update.
is_first_update (bool): Whether this is the initial update (skips difference computation).
"""
features = features.to(self.taylor_factors_dtype)
new_factors = {0: features}
if not is_first_update:
if self.last_update_step is None:
raise ValueError("Cannot update without prior initialization.")
delta_step = current_step - self.last_update_step
if delta_step == 0:
raise ValueError("Delta step cannot be zero for updates.")
for i in range(max_order):
if i in self.taylor_factors:
# Finite difference: (current - previous) / delta for forward approximation
new_factors[i + 1] = (new_factors[i] - self.taylor_factors[i].to(self.taylor_factors_dtype)) / delta_step

# taylor factors will be kept in the taylor_factors_dtype
self.taylor_factors = new_factors
self.last_update_step = current_step
self.remaining_predictions = predict_steps

def predict(self, current_step: int) -> torch.Tensor:
"""
Predicts the features using the Taylor series expansion at the given timestep.

Args:
current_step (int): The current timestep for prediction.

Returns:
torch.Tensor: The predicted features in the module's dtype.
"""
if self.last_update_step is None:
raise ValueError("Cannot predict without prior update.")
step_offset = current_step - self.last_update_step
device = self.taylor_factors[0].device
output = torch.zeros_like(self.taylor_factors[0], device=device)
for i in range(len(self.taylor_factors)):
output += self.taylor_factors[i] * (k ** i) / math.factorial(i)
self.predict_counter -= 1
logger.debug(f"predict_counter: {self.predict_counter}")
logger.debug(f"k: {k}")
return output
output = torch.zeros_like(self.taylor_factors[0], device=device, dtype=self.taylor_factors_dtype)
for order in range(len(self.taylor_factors)):
output += self.taylor_factors[order] * (step_offset ** order) / math.factorial(order)
self.remaining_predictions -= 1
# output will be converted to the module's dtype
return output.to(self.module_dtype)
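To make the `update`/`predict` pair above concrete: the state keeps finite-difference estimates F_i of the feature derivatives and extrapolates `features(t0 + k) ≈ sum_i F_i * k^i / i!`. Here is a self-contained first-order sketch of the same arithmetic; `update_factors` and `extrapolate` are hypothetical names, not part of this commit.

```python
import math

import torch


# Illustrative only: the finite-difference update and Taylor extrapolation used above,
# written as free functions over a dict of factors {order: tensor}.
def update_factors(factors: dict, features: torch.Tensor, delta_step: int, max_order: int = 1) -> dict:
    new_factors = {0: features}
    for i in range(max_order):
        if i in factors:
            # finite-difference estimate of the (i+1)-th derivative
            new_factors[i + 1] = (new_factors[i] - factors[i]) / delta_step
    return new_factors


def extrapolate(factors: dict, step_offset: int) -> torch.Tensor:
    output = torch.zeros_like(factors[0])
    for order, factor in factors.items():
        output += factor * (step_offset**order) / math.factorial(order)
    return output


# Features that change linearly with the step are reproduced exactly at max_order=1.
def features_at(step: int) -> torch.Tensor:
    return torch.full((2, 2), float(step))


factors = update_factors({}, features_at(0), delta_step=1)        # first update: no differences yet
factors = update_factors(factors, features_at(1), delta_step=1)   # second update: estimates the slope
print(torch.allclose(extrapolate(factors, step_offset=2), features_at(3)))  # True
```

In the hook itself, `delta_step` and `step_offset` come from its internal step counter, and the factors are stored in `taylor_factors_dtype` before the prediction is cast back to the module's dtype.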

class TaylorSeerAttentionCacheHook(ModelHook):
"""
Hook for caching and predicting attention outputs using Taylor series approximations.
Applies to attention modules in diffusion models (e.g., Flux).
Performs full computations during warmup, then alternates between predictions and refreshes.
"""
_is_stateful = True

def __init__(self, fresh_threshold: int, max_order: int, current_timestep_callback: Callable[[], int], warmup_steps: int):
def __init__(
self,
module_name: str,
predict_steps: int,
max_order: int,
warmup_steps: int,
taylor_factors_dtype: torch.dtype,
module_dtype: torch.dtype = None,
):
super().__init__()
self.fresh_threshold = fresh_threshold
self.module_name = module_name
self.predict_steps = predict_steps
self.max_order = max_order
self.current_timestep_callback = current_timestep_callback
self.warmup_steps = warmup_steps
self.step_counter = -1
self.states: Optional[List[TaylorSeerOutputState]] = None
self.num_outputs: Optional[int] = None
self.taylor_factors_dtype = taylor_factors_dtype
self.module_dtype = module_dtype

def initialize_hook(self, module):
def initialize_hook(self, module: torch.nn.Module):
self.step_counter = -1
self.states = None
self.num_outputs = None
self.warmup_steps_counter = 0
self.module_dtype = None
return module

def new_forward(self, module: torch.nn.Module, *args, **kwargs):
current_step = self.current_timestep_callback()
assert current_step is not None, "timestep is required for TaylorSeerAttentionCacheHook"
self.step_counter += 1
is_warmup_phase = self.step_counter < self.warmup_steps

if self.states is None:
# First step: always full compute and initialize
attention_outputs = self.fn_ref.original_forward(*args, **kwargs)
if isinstance(attention_outputs, torch.Tensor):
attention_outputs = [attention_outputs]
else:
attention_outputs = list(attention_outputs)
module_dtype = attention_outputs[0].dtype
self.num_outputs = len(attention_outputs)
self.states = [TaylorSeerState() for _ in range(self.num_outputs)]
for i, feat in enumerate(attention_outputs):
self.states[i].update(feat, current_step, self.max_order, self.fresh_threshold)
return attention_outputs[0] if len(attention_outputs) == 1 else attention_outputs
self.states = [
TaylorSeerOutputState(self.module_name, self.taylor_factors_dtype, module_dtype)
for _ in range(self.num_outputs)
]
for i, features in enumerate(attention_outputs):
self.states[i].update(features, self.step_counter, self.max_order, self.predict_steps, is_first_update=True)
return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs)

should_predict = self.states[0].predict_counter > 0 and self.warmup_steps_counter > self.warmup_steps

if not should_predict:
should_predict = self.states[0].remaining_predictions > 0
if is_warmup_phase or not should_predict:
# Full compute during warmup or when refresh needed
attention_outputs = self.fn_ref.original_forward(*args, **kwargs)
if isinstance(attention_outputs, torch.Tensor):
attention_outputs = [attention_outputs]
for i, feat in enumerate(attention_outputs):
self.states[i].update(feat, current_step, self.max_order, self.fresh_threshold)
return attention_outputs[0] if len(attention_outputs) == 1 else attention_outputs
else:
attention_outputs = list(attention_outputs)
is_first_update = self.step_counter == 0  # Only True for the very first step
for i, features in enumerate(attention_outputs):
self.states[i].update(features, self.step_counter, self.max_order, self.predict_steps, is_first_update)
return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs)
else:
predicted_outputs = [state.predict(current_step) for state in self.states]
return predicted_outputs[0] if len(predicted_outputs) == 1 else predicted_outputs
# Predict using Taylor series
predicted_outputs = [state.predict(self.step_counter) for state in self.states]
return predicted_outputs[0] if self.num_outputs == 1 else tuple(predicted_outputs)

def reset_state(self, module: torch.nn.Module) -> None:
if self.states is not None:
for state in self.states:
state.reset()
return module

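The hook above relies on diffusers' `ModelHook`/`HookRegistry` machinery to take over each attention module's forward call. As a rough mental model only (a toy stand-in, not the registry API used in this commit), the interception amounts to something like the following sketch.

```python
import torch


# Toy illustration: wrap a module's forward so some calls return a cached output
# instead of recomputing, as the hook does per attention module. The real
# implementation goes through HookRegistry and returns a Taylor prediction
# rather than the cached tensor unchanged.
class ToyCachedForward:
    def __init__(self, module: torch.nn.Module, predict_steps: int = 2):
        self._original_forward = module.forward
        self.predict_steps = predict_steps
        self.remaining = 0
        self.cached = None
        module.forward = self  # intercept subsequent calls

    def __call__(self, *args, **kwargs):
        if self.cached is None or self.remaining == 0:
            self.cached = self._original_forward(*args, **kwargs)  # full compute + refresh
            self.remaining = self.predict_steps
        else:
            self.remaining -= 1  # stand-in for a Taylor prediction
        return self.cached


layer = torch.nn.Linear(4, 4)
ToyCachedForward(layer, predict_steps=2)
x = torch.randn(1, 4)
_ = [layer(x) for _ in range(4)]  # compute, reuse, reuse, compute
```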
def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfig):
"""
Applies the TaylorSeer cache to the given pipeline.

Args:
module (torch.nn.Module): The model to apply the hook to.
config (TaylorSeerCacheConfig): Configuration for the cache.

Example:
```python
>>> import torch
>>> from diffusers import FluxPipeline, TaylorSeerCacheConfig, apply_taylorseer_cache

>>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")

>>> config = TaylorSeerCacheConfig(predict_steps=5, max_order=1, warmup_steps=3, taylor_factors_dtype=torch.float32)
>>> apply_taylorseer_cache(pipe.transformer, config)
```
"""
for name, submodule in module.named_modules():
if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
continue
logger.debug(f"Applying TaylorSeer cache to {name}")
_apply_taylorseer_cache_on_attention_class(name, submodule, config)
if isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
logger.debug(f"Applying TaylorSeer cache to {name}")
_apply_taylorseer_cache_hook(name, submodule, config)

def _apply_taylorseer_cache_hook(name: str, module: Attention, config: TaylorSeerCacheConfig):
"""
Registers the TaylorSeer hook on the specified attention module.

def _apply_taylorseer_cache_on_attention_class(name: str, module: Attention, config: TaylorSeerCacheConfig):
_apply_taylorseer_cache_hook(module, config)


def _apply_taylorseer_cache_hook(module: Attention, config: TaylorSeerCacheConfig):
Args:
name (str): Name of the module.
module (Attention): The attention module.
config (TaylorSeerCacheConfig): Configuration for the cache.
"""
registry = HookRegistry.check_if_exists_or_initialize(module)
hook = TaylorSeerAttentionCacheHook(config.fresh_threshold, config.max_order, config.current_timestep_callback, config.warmup_steps)
hook = TaylorSeerAttentionCacheHook(
name,
config.predict_steps,
config.max_order,
config.warmup_steps,
config.taylor_factors_dtype,
)
registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK)
@@ -67,9 +67,11 @@ class CacheMixin:
FasterCacheConfig,
FirstBlockCacheConfig,
PyramidAttentionBroadcastConfig,
TaylorSeerCacheConfig,
apply_faster_cache,
apply_first_block_cache,
apply_pyramid_attention_broadcast,
apply_taylorseer_cache,
)

if self.is_cache_enabled:
@@ -83,16 +85,19 @@ class CacheMixin:
apply_first_block_cache(self, config)
elif isinstance(config, PyramidAttentionBroadcastConfig):
apply_pyramid_attention_broadcast(self, config)
elif isinstance(config, TaylorSeerCacheConfig):
apply_taylorseer_cache(self, config)
else:
raise ValueError(f"Cache config {type(config)} is not supported.")

self._cache_config = config

def disable_cache(self) -> None:
from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig, TaylorSeerCacheConfig
from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
from ..hooks.taylorseer_cache import _TAYLORSEER_ATTENTION_CACHE_HOOK

if self._cache_config is None:
logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
@@ -107,6 +112,8 @@ class CacheMixin:
registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True)
elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
elif isinstance(self._cache_config, TaylorSeerCacheConfig):
registry.remove_hook(_TAYLORSEER_ATTENTION_CACHE_HOOK, recurse=True)
else:
raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")