
add configurable cache, skip compute module

This commit is contained in:
toilaluan
2025-11-14 09:09:46 +00:00
parent 1099e493e6
commit 7b4ad2de63


@@ -10,10 +10,28 @@ from ._common import (
 )
 from ..hooks import HookRegistry
 from ..utils import logging
+import re
+from collections import defaultdict
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 _TAYLORSEER_ATTENTION_CACHE_HOOK = "taylorseer_attention_cache"
 
+SPECIAL_CACHE_IDENTIFIERS = {
+    "flux": [
+        r"transformer_blocks\.\d+\.attn",
+        r"transformer_blocks\.\d+\.ff",
+        r"transformer_blocks\.\d+\.ff_context",
+        r"single_transformer_blocks\.\d+\.proj_out",
+    ]
+}
+
+SKIP_COMPUTE_IDENTIFIERS = {
+    "flux": [
+        r"single_transformer_blocks\.\d+\.attn",
+        r"single_transformer_blocks\.\d+\.proj_mlp",
+        r"single_transformer_blocks\.\d+\.act_mlp",
+    ]
+}
+
 
 @dataclass
 class TaylorSeerCacheConfig:
     """
@@ -25,14 +43,22 @@ class TaylorSeerCacheConfig:
         predict_steps (int, defaults to 5): Number of prediction (cache) steps between non-cached steps.
         max_order (int, defaults to 1): Maximum order of the Taylor series expansion used to approximate the features.
         taylor_factors_dtype (torch.dtype, defaults to torch.float32): Data type for the Taylor series expansion factors.
+        architecture (str, defaults to None): Model architecture the cache is applied to. For a known
+            architecture (e.g. "flux"), the built-in identifier patterns are used when none are given explicitly.
+        skip_compute_identifiers (List[str], defaults to None): Regex patterns for modules whose computation is
+            skipped entirely during cached steps (a dummy tensor is returned instead).
+        special_cache_identifiers (List[str], defaults to None): Regex patterns for modules whose outputs are
+            cached and approximated with the Taylor expansion.
     """
 
     warmup_steps: int = 3
     predict_steps: int = 5
     max_order: int = 1
     taylor_factors_dtype: torch.dtype = torch.float32
+    architecture: Optional[str] = None
+    skip_compute_identifiers: Optional[List[str]] = None
+    special_cache_identifiers: Optional[List[str]] = None
 
     def __repr__(self) -> str:
-        return f"TaylorSeerCacheConfig(warmup_steps={self.warmup_steps}, predict_steps={self.predict_steps}, max_order={self.max_order}, taylor_factors_dtype={self.taylor_factors_dtype})"
+        return (
+            f"TaylorSeerCacheConfig(warmup_steps={self.warmup_steps}, predict_steps={self.predict_steps}, "
+            f"max_order={self.max_order}, taylor_factors_dtype={self.taylor_factors_dtype}, "
+            f"architecture={self.architecture}, skip_compute_identifiers={self.skip_compute_identifiers}, "
+            f"special_cache_identifiers={self.special_cache_identifiers})"
+        )
 class TaylorSeerOutputState:
     """
@@ -41,20 +67,31 @@ class TaylorSeerOutputState:
     The Taylor expansion uses the timestep as the independent variable for approximation.
     """
 
-    def __init__(self, module_name: str, taylor_factors_dtype: torch.dtype, module_dtype: torch.dtype):
+    def __init__(
+        self, module_name: str, taylor_factors_dtype: torch.dtype, module_dtype: torch.dtype, is_skip: bool = False
+    ):
         self.module_name = module_name
         self.remaining_predictions: int = 0
         self.last_update_step: Optional[int] = None
         self.taylor_factors: Dict[int, torch.Tensor] = {}
         self.taylor_factors_dtype = taylor_factors_dtype
         self.module_dtype = module_dtype
+        self.is_skip = is_skip
+        self.dummy_shape: Optional[Tuple[int, ...]] = None
+        self.device: Optional[torch.device] = None
+        self.dummy_tensor: Optional[torch.Tensor] = None
 
     def reset(self):
         self.remaining_predictions = 0
         self.last_update_step = None
         self.taylor_factors = {}
+        self.dummy_shape = None
+        self.device = None
+        self.dummy_tensor = None
 
-    def update(self, features: torch.Tensor, current_step: int, max_order: int, predict_steps: int, is_first_update: bool):
+    def update(
+        self, features: torch.Tensor, current_step: int, max_order: int, predict_steps: int, is_first_update: bool
+    ):
"""
Updates the Taylor factors based on the current features and timestep.
Computes finite difference approximations for derivatives using recursive divided differences.
@@ -66,23 +103,33 @@ class TaylorSeerOutputState:
predict_steps (int): Number of prediction steps to set after update.
is_first_update (bool): Whether this is the initial update (skips difference computation).
"""
-        features = features.to(self.taylor_factors_dtype)
-        new_factors = {0: features}
-        if not is_first_update:
-            if self.last_update_step is None:
-                raise ValueError("Cannot update without prior initialization.")
-            delta_step = current_step - self.last_update_step
-            if delta_step == 0:
-                raise ValueError("Delta step cannot be zero for updates.")
-            for i in range(max_order):
-                if i in self.taylor_factors:
-                    # Finite difference: (current - previous) / delta for forward approximation
-                    new_factors[i + 1] = (new_factors[i] - self.taylor_factors[i].to(self.taylor_factors_dtype)) / delta_step
-        # taylor factors will be kept in the taylor_factors_dtype
-        self.taylor_factors = new_factors
-        self.last_update_step = current_step
-        self.remaining_predictions = predict_steps
+        if self.is_skip:
+            self.dummy_shape = features.shape
+            self.device = features.device
+            self.taylor_factors = {}
+            self.last_update_step = current_step
+            self.remaining_predictions = predict_steps
+        else:
+            features = features.to(self.taylor_factors_dtype)
+            new_factors = {0: features}
+            if not is_first_update:
+                if self.last_update_step is None:
+                    raise ValueError("Cannot update without prior initialization.")
+                delta_step = current_step - self.last_update_step
+                if delta_step == 0:
+                    raise ValueError("Delta step cannot be zero for updates.")
+                for i in range(max_order):
+                    if i in self.taylor_factors:
+                        new_factors[i + 1] = (
+                            new_factors[i] - self.taylor_factors[i].to(self.taylor_factors_dtype)
+                        ) / delta_step
+                    else:
+                        break
+            # taylor factors will be kept in the taylor_factors_dtype
+            self.taylor_factors = new_factors
+            self.last_update_step = current_step
+            self.remaining_predictions = predict_steps
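(Editorial note, not part of the diff.) The recursion above is a forward divided-difference estimate of the feature derivatives, with the timestep as the independent variable. Writing F_k for `taylor_factors[k]` and Δt for `delta_step`:

```latex
F^{\mathrm{new}}_0 = y(t_{\mathrm{curr}}), \qquad
F^{\mathrm{new}}_{k+1} = \frac{F^{\mathrm{new}}_k - F^{\mathrm{old}}_k}{\Delta t},
\quad k = 0, \dots, \texttt{max\_order} - 1
```

so F_k approximates the k-th derivative of the cached features at the last update step.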
     def predict(self, current_step: int) -> torch.Tensor:
         """
@@ -94,16 +141,22 @@ class TaylorSeerOutputState:
         Returns:
             torch.Tensor: The predicted features in the module's dtype.
         """
-        if self.last_update_step is None:
-            raise ValueError("Cannot predict without prior update.")
-        step_offset = current_step - self.last_update_step
-        device = self.taylor_factors[0].device
-        output = torch.zeros_like(self.taylor_factors[0], device=device, dtype=self.taylor_factors_dtype)
-        for order in range(len(self.taylor_factors)):
-            output += self.taylor_factors[order] * (step_offset ** order) / math.factorial(order)
-        self.remaining_predictions -= 1
-        # output will be converted to the module's dtype
-        return output.to(self.module_dtype)
+        if self.is_skip:
+            if self.dummy_shape is None or self.device is None:
+                raise ValueError("Cannot predict for skip module without prior update.")
+            self.remaining_predictions -= 1
+            return torch.empty(self.dummy_shape, dtype=self.module_dtype, device=self.device)
+        else:
+            if self.last_update_step is None:
+                raise ValueError("Cannot predict without prior update.")
+            step_offset = current_step - self.last_update_step
+            output = 0
+            for order in range(len(self.taylor_factors)):
+                output += self.taylor_factors[order] * (step_offset**order) * (1 / math.factorial(order))
+            self.remaining_predictions -= 1
+            # output will be converted to the module's dtype
+            return output.to(self.module_dtype)
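(Editorial note.) The non-skip branch evaluates the truncated Taylor polynomial around the last fully computed step, with K = `len(self.taylor_factors) - 1` and t_last = `last_update_step`:

```latex
\hat{y}(t) = \sum_{k=0}^{K} F_k \, \frac{(t - t_{\mathrm{last}})^k}{k!}
```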
 class TaylorSeerAttentionCacheHook(ModelHook):
     """
@@ -111,6 +164,7 @@ class TaylorSeerAttentionCacheHook(ModelHook):
     Applies to attention modules in diffusion models (e.g., Flux).
     Performs full computations during warmup, then alternates between predictions and refreshes.
     """
+
     _is_stateful = True
 
     def __init__(
@@ -120,7 +174,7 @@ class TaylorSeerAttentionCacheHook(ModelHook):
         max_order: int,
         warmup_steps: int,
         taylor_factors_dtype: torch.dtype,
-        module_dtype: torch.dtype = None,
+        is_skip_compute: bool = False,
     ):
         super().__init__()
         self.module_name = module_name
@@ -131,13 +185,12 @@ class TaylorSeerAttentionCacheHook(ModelHook):
         self.states: Optional[List[TaylorSeerOutputState]] = None
         self.num_outputs: Optional[int] = None
         self.taylor_factors_dtype = taylor_factors_dtype
-        self.module_dtype = module_dtype
+        self.is_skip_compute = is_skip_compute
 
     def initialize_hook(self, module: torch.nn.Module):
         self.step_counter = -1
         self.states = None
         self.num_outputs = None
-        self.module_dtype = None
         return module
 
     def new_forward(self, module: torch.nn.Module, *args, **kwargs):
@@ -154,11 +207,15 @@ class TaylorSeerAttentionCacheHook(ModelHook):
             module_dtype = attention_outputs[0].dtype
             self.num_outputs = len(attention_outputs)
             self.states = [
-                TaylorSeerOutputState(self.module_name, self.taylor_factors_dtype, module_dtype)
+                TaylorSeerOutputState(
+                    self.module_name, self.taylor_factors_dtype, module_dtype, is_skip=self.is_skip_compute
+                )
                 for _ in range(self.num_outputs)
             ]
             for i, features in enumerate(attention_outputs):
-                self.states[i].update(features, self.step_counter, self.max_order, self.predict_steps, is_first_update=True)
+                self.states[i].update(
+                    features, self.step_counter, self.max_order, self.predict_steps, is_first_update=True
+                )
             return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs)
 
         should_predict = self.states[0].remaining_predictions > 0
@@ -179,9 +236,8 @@ class TaylorSeerAttentionCacheHook(ModelHook):
         return predicted_outputs[0] if self.num_outputs == 1 else tuple(predicted_outputs)
 
     def reset_state(self, module: torch.nn.Module) -> None:
         if self.states is not None:
             for state in self.states:
                 state.reset()
         self.states = None
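For intuition, the cache-or-compute cadence driven by `remaining_predictions` can be simulated in isolation. A minimal runnable sketch (editorial, warmup handling omitted; only `predict_steps` and the prediction budget mirror the hook):

```python
# After each full compute, update() resets the budget to predict_steps;
# each predict() call spends one unit of it.
predict_steps = 3
remaining_predictions = 0
schedule = []
for step in range(10):
    if remaining_predictions > 0:
        remaining_predictions -= 1
        schedule.append("predict")
    else:
        remaining_predictions = predict_steps
        schedule.append("compute")
print(schedule)
# ['compute', 'predict', 'predict', 'predict', 'compute', 'predict', ...]
```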
 def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfig):
     """
@@ -199,30 +255,57 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfi
         >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
         >>> pipe.to("cuda")
-        >>> config = TaylorSeerCacheConfig(predict_steps=5, max_order=1, warmup_steps=3, taylor_factors_dtype=torch.float32)
+        >>> config = TaylorSeerCacheConfig(predict_steps=5, max_order=1, warmup_steps=3, taylor_factors_dtype=torch.float32, architecture="flux")
         >>> apply_taylorseer_cache(pipe.transformer, config)
         ```
     """
+    if config.skip_compute_identifiers:
+        skip_compute_identifiers = config.skip_compute_identifiers
+    else:
+        skip_compute_identifiers = SKIP_COMPUTE_IDENTIFIERS.get(config.architecture, [])
+    if config.special_cache_identifiers:
+        special_cache_identifiers = config.special_cache_identifiers
+    else:
+        special_cache_identifiers = SPECIAL_CACHE_IDENTIFIERS.get(config.architecture, [])
+
+    logger.debug(f"Skip compute identifiers: {skip_compute_identifiers}")
+    logger.debug(f"Special cache identifiers: {special_cache_identifiers}")
+
     for name, submodule in module.named_modules():
-        if isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
+        if skip_compute_identifiers and special_cache_identifiers:
+            if any(re.fullmatch(identifier, name) for identifier in skip_compute_identifiers) or any(
+                re.fullmatch(identifier, name) for identifier in special_cache_identifiers
+            ):
+                logger.debug(f"Applying TaylorSeer cache to {name}")
+                _apply_taylorseer_cache_hook(name, submodule, config)
+        elif isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
             logger.debug(f"Applying TaylorSeer cache to {name}")
             _apply_taylorseer_cache_hook(name, submodule, config)
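As a quick illustration (editorial sketch, not part of the commit) of how these patterns select modules: `re.fullmatch` requires the entire dotted name to match, so child modules are not caught by a parent's pattern. The module names below are hypothetical examples of what `named_modules()` yields on a Flux transformer:

```python
import re

names = [
    "transformer_blocks.0.attn",
    "single_transformer_blocks.3.attn",
    "single_transformer_blocks.3.attn.to_q",  # child module: not a full match
]
skip_patterns = [r"single_transformer_blocks\.\d+\.attn"]

matched = [n for n in names if any(re.fullmatch(p, n) for p in skip_patterns)]
print(matched)  # ['single_transformer_blocks.3.attn']
```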
 
 def _apply_taylorseer_cache_hook(name: str, module: Attention, config: TaylorSeerCacheConfig):
     """
     Registers the TaylorSeer hook on the specified attention module.
 
     Args:
         name (str): Name of the module.
         module (Attention): The attention module.
         config (TaylorSeerCacheConfig): Configuration for the cache.
     """
+    skip_compute_identifiers = config.skip_compute_identifiers or SKIP_COMPUTE_IDENTIFIERS.get(
+        config.architecture, []
+    )
+    is_skip_compute = any(re.fullmatch(identifier, name) for identifier in skip_compute_identifiers)
     registry = HookRegistry.check_if_exists_or_initialize(module)
     hook = TaylorSeerAttentionCacheHook(
         name,
         config.predict_steps,
         config.max_order,
         config.warmup_steps,
         config.taylor_factors_dtype,
+        is_skip_compute=is_skip_compute,
     )
-    registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK)
\ No newline at end of file
+    registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK)