mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

add stop_predicts (cooldown)

Author: toilaluan
Date: 2025-11-16 05:09:44 +00:00
Commit: 7238d40dd9
Parent: 51b4318a3e


@@ -1,6 +1,6 @@
 import torch
 from dataclasses import dataclass
-from typing import Callable, Optional, List, Dict
+from typing import Callable, Optional, List, Dict, Tuple
 from .hooks import ModelHook
 import math
 from ..models.attention import Attention
@@ -12,23 +12,28 @@ from ..hooks import HookRegistry
 from ..utils import logging
 import re
 from collections import defaultdict

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

 _TAYLORSEER_ATTENTION_CACHE_HOOK = "taylorseer_attention_cache"

-SPECIAL_CACHE_IDENTIFIERS = {
-    "flux": [
-        r"transformer_blocks\.\d+\.attn",
-        r"transformer_blocks\.\d+\.ff",
-        r"transformer_blocks\.\d+\.ff_context",
-        r"single_transformer_blocks\.\d+\.proj_out",
-    ]
-}
-SKIP_COMPUTE_IDENTIFIERS = {
-    "flux": [
-        r"single_transformer_blocks\.\d+\.attn",
-        r"single_transformer_blocks\.\d+\.proj_mlp",
-        r"single_transformer_blocks\.\d+\.act_mlp",
-    ]
-}
+# Predefined cache templates for optimized architectures
+_CACHE_TEMPLATES = {
+    "flux": {
+        "cache": [
+            r"transformer_blocks\.\d+\.attn",
+            r"transformer_blocks\.\d+\.ff",
+            r"transformer_blocks\.\d+\.ff_context",
+            r"single_transformer_blocks\.\d+\.proj_out",
+        ],
+        "skip": [
+            r"single_transformer_blocks\.\d+\.attn",
+            r"single_transformer_blocks\.\d+\.proj_mlp",
+            r"single_transformer_blocks\.\d+\.act_mlp",
+        ],
+    },
+}
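Not part of the commit, but for orientation: the later hunks read this structure back with chained `dict.get` lookups, so an architecture without a template silently falls back to empty identifier lists. A minimal sketch of that behaviour (trimmed template, illustrative only):

```python
# Illustrative only: mimics the lookups used later in this diff.
_CACHE_TEMPLATES = {
    "flux": {
        "cache": [r"transformer_blocks\.\d+\.attn"],
        "skip": [r"single_transformer_blocks\.\d+\.attn"],
    },
}

print(_CACHE_TEMPLATES.get("flux", {}).get("skip", []))
# ['single_transformer_blocks\\.\\d+\\.attn']
print(_CACHE_TEMPLATES.get("some_other_arch", {}).get("cache", []))
# [] -> unknown architectures fall back to empty identifier lists
```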
@@ -41,24 +46,39 @@ class TaylorSeerCacheConfig:
     Attributes:
         warmup_steps (int, defaults to 3): Number of warmup steps without caching.
         predict_steps (int, defaults to 5): Number of prediction (cache) steps between non-cached steps.
+        stop_predicts (Optional[int], defaults to None): Step after which predictions are stopped and full computation is always performed.
         max_order (int, defaults to 1): Maximum order of Taylor series expansion to approximate the features.
         taylor_factors_dtype (torch.dtype, defaults to torch.float32): Data type for Taylor series expansion factors.
         architecture (str, defaults to None): Architecture for which the cache is applied. If we know the architecture, we can use the special cache identifiers.
-        skip_compute_identifiers (List[str], defaults to []): Identifiers for modules to skip computation.
-        special_cache_identifiers (List[str], defaults to []): Identifiers for modules to use special cache.
+        skip_identifiers (List[str], defaults to []): Identifiers for modules to skip computation.
+        cache_identifiers (List[str], defaults to []): Identifiers for modules to cache.
+            By default, this approximation can be applied to all attention modules, but in some architectures, where the outputs of attention modules are not used for any residual computation, we can skip this attention cache step, so we have to identify the next modules to cache.
+
+    Example:
+        ```python
+        ...
+        def forward(self, x: torch.Tensor) -> torch.Tensor:
+            attn_output = self.attention(x)  # mark this attention module to skip computation
+            ffn_output = self.ffn(attn_output)  # ffn_output will be cached
+            return ffn_output
+        ```
     """

     warmup_steps: int = 3
     predict_steps: int = 5
+    stop_predicts: Optional[int] = None
     max_order: int = 1
     taylor_factors_dtype: torch.dtype = torch.float32
     architecture: str | None = None
-    skip_compute_identifiers: List[str] = None
-    special_cache_identifiers: List[str] = None
+    skip_identifiers: List[str] = None
+    cache_identifiers: List[str] = None

     def __repr__(self) -> str:
-        return f"TaylorSeerCacheConfig(warmup_steps={self.warmup_steps}, predict_steps={self.predict_steps}, max_order={self.max_order}, taylor_factors_dtype={self.taylor_factors_dtype}, architecture={self.architecture}, skip_compute_identifiers={self.skip_compute_identifiers}, special_cache_identifiers={self.special_cache_identifiers})"
+        return f"TaylorSeerCacheConfig(warmup_steps={self.warmup_steps}, predict_steps={self.predict_steps}, stop_predicts={self.stop_predicts}, max_order={self.max_order}, taylor_factors_dtype={self.taylor_factors_dtype}, architecture={self.architecture}, skip_identifiers={self.skip_identifiers}, cache_identifiers={self.cache_identifiers})"
+
+    @classmethod
+    def get_identifiers_template(self) -> Dict[str, Dict[str, List[str]]]:
+        return _CACHE_TEMPLATES


 class TaylorSeerOutputState:
     """
@@ -174,18 +194,20 @@ class TaylorSeerAttentionCacheHook(ModelHook):
         max_order: int,
         warmup_steps: int,
         taylor_factors_dtype: torch.dtype,
-        is_skip_compute: bool = False,
+        stop_predicts: Optional[int] = None,
+        is_skip: bool = False,
     ):
         super().__init__()
         self.module_name = module_name
         self.predict_steps = predict_steps
         self.max_order = max_order
         self.warmup_steps = warmup_steps
+        self.stop_predicts = stop_predicts
         self.step_counter = -1
         self.states: Optional[List[TaylorSeerOutputState]] = None
         self.num_outputs: Optional[int] = None
         self.taylor_factors_dtype = taylor_factors_dtype
-        self.is_skip_compute = is_skip_compute
+        self.is_skip = is_skip

     def initialize_hook(self, module: torch.nn.Module):
         self.step_counter = -1
@@ -208,7 +230,7 @@ class TaylorSeerAttentionCacheHook(ModelHook):
             self.num_outputs = len(attention_outputs)
             self.states = [
                 TaylorSeerOutputState(
-                    self.module_name, self.taylor_factors_dtype, module_dtype, is_skip=self.is_skip_compute
+                    self.module_name, self.taylor_factors_dtype, module_dtype, is_skip=self.is_skip
                 )
                 for _ in range(self.num_outputs)
             ]
@@ -218,22 +240,31 @@ class TaylorSeerAttentionCacheHook(ModelHook):
                 )
             return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs)

-        should_predict = self.states[0].remaining_predictions > 0
-        if is_warmup_phase or not should_predict:
-            # Full compute during warmup or when refresh needed
+        if self.stop_predicts is not None and self.step_counter >= self.stop_predicts:
+            # After stop_predicts: always full compute without updating state
             attention_outputs = self.fn_ref.original_forward(*args, **kwargs)
             if isinstance(attention_outputs, torch.Tensor):
                 attention_outputs = [attention_outputs]
             else:
                 attention_outputs = list(attention_outputs)
-            is_first_update = self.step_counter == 0  # Only True for the very first step
-            for i, features in enumerate(attention_outputs):
-                self.states[i].update(features, self.step_counter, self.max_order, self.predict_steps, is_first_update)
             return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs)
-        else:
-            # Predict using Taylor series
-            predicted_outputs = [state.predict(self.step_counter) for state in self.states]
-            return predicted_outputs[0] if self.num_outputs == 1 else tuple(predicted_outputs)
+
+        should_predict = self.states[0].remaining_predictions > 0
+        if is_warmup_phase or not should_predict:
+            # Full compute during warmup or when refresh needed
+            attention_outputs = self.fn_ref.original_forward(*args, **kwargs)
+            if isinstance(attention_outputs, torch.Tensor):
+                attention_outputs = [attention_outputs]
+            else:
+                attention_outputs = list(attention_outputs)
+            is_first_update = self.step_counter == 0  # Only True for the very first step
+            for i, features in enumerate(attention_outputs):
+                self.states[i].update(features, self.step_counter, self.max_order, self.predict_steps, is_first_update)
+            return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs)
+        else:
+            # Predict using Taylor series
+            predicted_outputs = [state.predict(self.step_counter) for state in self.states]
+            return predicted_outputs[0] if self.num_outputs == 1 else tuple(predicted_outputs)

     def reset_state(self, module: torch.nn.Module) -> None:
         self.states = None
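Not part of the commit, but the new branching is easiest to read as a step schedule. The sketch below mirrors the control flow under stated assumptions: the warmup phase covers the first `warmup_steps` calls, a full-compute step refreshes a prediction budget of `predict_steps` (what `update()` is assumed to do), a predicted step consumes one unit of that budget (what `predict()` is assumed to do), and from `stop_predicts` onward every step is full compute with no state updates.

```python
from typing import List, Optional


def schedule(num_steps: int, warmup_steps: int = 3, predict_steps: int = 5,
             stop_predicts: Optional[int] = None) -> List[str]:
    """Label each step the way the hook above would treat it (simplified sketch)."""
    labels = []
    remaining_predictions = 0  # assumption: update() resets this to predict_steps, predict() decrements it
    for step in range(num_steps):
        if step < warmup_steps:
            labels.append("warmup/full")    # always compute; Taylor factors keep getting updated
            remaining_predictions = predict_steps
        elif stop_predicts is not None and step >= stop_predicts:
            labels.append("cooldown/full")  # new in this commit: full compute, cache state untouched
        elif remaining_predictions > 0:
            labels.append("predict")        # Taylor-series approximation of the cached outputs
            remaining_predictions -= 1
        else:
            labels.append("refresh/full")   # budget exhausted: recompute and refresh the cache
            remaining_predictions = predict_steps
    return labels


print(schedule(16, warmup_steps=3, predict_steps=5, stop_predicts=12))
# -> 3x warmup/full, 5x predict, 1x refresh/full, 3x predict, 4x cooldown/full
```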
@@ -259,23 +290,23 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfig
         >>> apply_taylorseer_cache(pipe.transformer, config)
         ```
     """
-    if config.skip_compute_identifiers:
-        skip_compute_identifiers = config.skip_compute_identifiers
+    if config.skip_identifiers:
+        skip_identifiers = config.skip_identifiers
     else:
-        skip_compute_identifiers = SKIP_COMPUTE_IDENTIFIERS.get(config.architecture, [])
+        skip_identifiers = _CACHE_TEMPLATES.get(config.architecture, {}).get("skip", [])

-    if config.special_cache_identifiers:
-        special_cache_identifiers = config.special_cache_identifiers
+    if config.cache_identifiers:
+        cache_identifiers = config.cache_identifiers
     else:
-        special_cache_identifiers = SPECIAL_CACHE_IDENTIFIERS.get(config.architecture, [])
+        cache_identifiers = _CACHE_TEMPLATES.get(config.architecture, {}).get("cache", [])

-    logger.debug(f"Skip compute identifiers: {skip_compute_identifiers}")
-    logger.debug(f"Special cache identifiers: {special_cache_identifiers}")
+    logger.debug(f"Skip identifiers: {skip_identifiers}")
+    logger.debug(f"Cache identifiers: {cache_identifiers}")

     for name, submodule in module.named_modules():
-        if (skip_compute_identifiers and special_cache_identifiers) or (special_cache_identifiers):
-            if any(re.fullmatch(identifier, name) for identifier in skip_compute_identifiers) or any(
-                re.fullmatch(identifier, name) for identifier in special_cache_identifiers
+        if (skip_identifiers and cache_identifiers) or (cache_identifiers):
+            if any(re.fullmatch(identifier, name) for identifier in skip_identifiers) or any(
+                re.fullmatch(identifier, name) for identifier in cache_identifiers
             ):
                 logger.debug(f"Applying TaylorSeer cache to {name}")
                 _apply_taylorseer_cache_hook(name, submodule, config)
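A side note (not in the commit): module selection relies on `re.fullmatch` against dotted module names, so an identifier must match the whole name and children of a matched module are not picked up. A small illustrative check with patterns taken from the Flux template above:

```python
import re

skip_identifiers = [r"single_transformer_blocks\.\d+\.attn"]
cache_identifiers = [r"transformer_blocks\.\d+\.attn", r"single_transformer_blocks\.\d+\.proj_out"]

names = [
    "transformer_blocks.3.attn",             # matches a cache identifier -> hooked
    "single_transformer_blocks.7.attn",      # matches a skip identifier  -> hooked with is_skip=True
    "single_transformer_blocks.7.proj_out",  # matches a cache identifier -> hooked
    "single_transformer_blocks.7.attn.to_q"  # child of a matched module  -> ignored (fullmatch only)
]
for name in names:
    hooked = any(re.fullmatch(p, name) for p in skip_identifiers) or any(
        re.fullmatch(p, name) for p in cache_identifiers
    )
    print(f"{name}: {'hooked' if hooked else 'ignored'}")
```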
@@ -293,8 +324,8 @@ def _apply_taylorseer_cache_hook(name: str, module: Attention, config: TaylorSeerCacheConfig
         config (TaylorSeerCacheConfig): Configuration for the cache.
     """
-    is_skip_compute = any(
-        re.fullmatch(identifier, name) for identifier in SKIP_COMPUTE_IDENTIFIERS.get(config.architecture, [])
+    is_skip = any(
+        re.fullmatch(identifier, name) for identifier in _CACHE_TEMPLATES.get(config.architecture, {}).get("skip", [])
     )

     registry = HookRegistry.check_if_exists_or_initialize(module)
@@ -305,7 +336,8 @@ def _apply_taylorseer_cache_hook(name: str, module: Attention, config: TaylorSeerCacheConfig
         config.max_order,
         config.warmup_steps,
         config.taylor_factors_dtype,
-        is_skip_compute=is_skip_compute,
+        stop_predicts=config.stop_predicts,
+        is_skip=is_skip,
     )

     registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK)