mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
add stop_predicts (cooldown)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import torch
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Optional, List, Dict
|
||||
from typing import Callable, Optional, List, Dict, Tuple
|
||||
from .hooks import ModelHook
|
||||
import math
|
||||
from ..models.attention import Attention
|
||||
@@ -12,23 +12,28 @@ from ..hooks import HookRegistry
|
||||
from ..utils import logging
|
||||
import re
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
_TAYLORSEER_ATTENTION_CACHE_HOOK = "taylorseer_attention_cache"
|
||||
|
||||
SPECIAL_CACHE_IDENTIFIERS = {
|
||||
"flux": [
|
||||
r"transformer_blocks\.\d+\.attn",
|
||||
r"transformer_blocks\.\d+\.ff",
|
||||
r"transformer_blocks\.\d+\.ff_context",
|
||||
r"single_transformer_blocks\.\d+\.proj_out",
|
||||
]
|
||||
}
|
||||
SKIP_COMPUTE_IDENTIFIERS = {
|
||||
"flux": [
|
||||
r"single_transformer_blocks\.\d+\.attn",
|
||||
r"single_transformer_blocks\.\d+\.proj_mlp",
|
||||
r"single_transformer_blocks\.\d+\.act_mlp",
|
||||
]
|
||||
# Predefined cache templates for optimized architectures
|
||||
_CACHE_TEMPLATES = {
|
||||
"flux": {
|
||||
"cache": [
|
||||
r"transformer_blocks\.\d+\.attn",
|
||||
r"transformer_blocks\.\d+\.ff",
|
||||
r"transformer_blocks\.\d+\.ff_context",
|
||||
r"single_transformer_blocks\.\d+\.proj_out",
|
||||
],
|
||||
"skip": [
|
||||
r"single_transformer_blocks\.\d+\.attn",
|
||||
r"single_transformer_blocks\.\d+\.proj_mlp",
|
||||
r"single_transformer_blocks\.\d+\.act_mlp",
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -41,24 +46,39 @@ class TaylorSeerCacheConfig:
|
||||
Attributes:
|
||||
warmup_steps (int, defaults to 3): Number of warmup steps without caching.
|
||||
predict_steps (int, defaults to 5): Number of prediction (cache) steps between non-cached steps.
|
||||
stop_predicts (Optional[int], defaults to None): Step after which predictions are stopped and full computation is always performed.
|
||||
max_order (int, defaults to 1): Maximum order of Taylor series expansion to approximate the features.
|
||||
taylor_factors_dtype (torch.dtype, defaults to torch.float32): Data type for Taylor series expansion factors.
|
||||
architecture (str, defaults to None): Architecture for which the cache is applied. If we know the architecture, we can use the special cache identifiers.
|
||||
skip_compute_identifiers (List[str], defaults to []): Identifiers for modules to skip computation.
|
||||
special_cache_identifiers (List[str], defaults to []): Identifiers for modules to use special cache.
|
||||
skip_identifiers (List[str], defaults to []): Identifiers for modules to skip computation.
|
||||
cache_identifiers (List[str], defaults to []): Identifiers for modules to cache.
|
||||
|
||||
By default, this approximation can be applied to all attention modules, but in some architectures, where the outputs of attention modules are not used for any residual computation, we can skip this attention cache step, so we have to identify the next modules to cache.
|
||||
Example:
|
||||
```python
|
||||
...
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
attn_output = self.attention(x) # mark this attention module to skip computation
|
||||
ffn_output = self.ffn(attn_output) # ffn_output will be cached
|
||||
return ffn_output
|
||||
```
|
||||
"""
|
||||
|
||||
warmup_steps: int = 3
|
||||
predict_steps: int = 5
|
||||
stop_predicts: Optional[int] = None
|
||||
max_order: int = 1
|
||||
taylor_factors_dtype: torch.dtype = torch.float32
|
||||
architecture: str | None = None
|
||||
skip_compute_identifiers: List[str] = None
|
||||
special_cache_identifiers: List[str] = None
|
||||
skip_identifiers: List[str] = None
|
||||
cache_identifiers: List[str] = None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"TaylorSeerCacheConfig(warmup_steps={self.warmup_steps}, predict_steps={self.predict_steps}, max_order={self.max_order}, taylor_factors_dtype={self.taylor_factors_dtype}, architecture={self.architecture}, skip_compute_identifiers={self.skip_compute_identifiers}, special_cache_identifiers={self.special_cache_identifiers})"
|
||||
return f"TaylorSeerCacheConfig(warmup_steps={self.warmup_steps}, predict_steps={self.predict_steps}, stop_predicts={self.stop_predicts}, max_order={self.max_order}, taylor_factors_dtype={self.taylor_factors_dtype}, architecture={self.architecture}, skip_identifiers={self.skip_identifiers}, cache_identifiers={self.cache_identifiers})"
|
||||
|
||||
@classmethod
|
||||
def get_identifiers_template(self) -> Dict[str, Dict[str, List[str]]]:
|
||||
return _CACHE_TEMPLATES
|
||||
|
||||
class TaylorSeerOutputState:
|
||||
"""
|
||||
@@ -174,18 +194,20 @@ class TaylorSeerAttentionCacheHook(ModelHook):
|
||||
max_order: int,
|
||||
warmup_steps: int,
|
||||
taylor_factors_dtype: torch.dtype,
|
||||
is_skip_compute: bool = False,
|
||||
stop_predicts: Optional[int] = None,
|
||||
is_skip: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.module_name = module_name
|
||||
self.predict_steps = predict_steps
|
||||
self.max_order = max_order
|
||||
self.warmup_steps = warmup_steps
|
||||
self.stop_predicts = stop_predicts
|
||||
self.step_counter = -1
|
||||
self.states: Optional[List[TaylorSeerOutputState]] = None
|
||||
self.num_outputs: Optional[int] = None
|
||||
self.taylor_factors_dtype = taylor_factors_dtype
|
||||
self.is_skip_compute = is_skip_compute
|
||||
self.is_skip = is_skip
|
||||
|
||||
def initialize_hook(self, module: torch.nn.Module):
|
||||
self.step_counter = -1
|
||||
@@ -208,7 +230,7 @@ class TaylorSeerAttentionCacheHook(ModelHook):
|
||||
self.num_outputs = len(attention_outputs)
|
||||
self.states = [
|
||||
TaylorSeerOutputState(
|
||||
self.module_name, self.taylor_factors_dtype, module_dtype, is_skip=self.is_skip_compute
|
||||
self.module_name, self.taylor_factors_dtype, module_dtype, is_skip=self.is_skip
|
||||
)
|
||||
for _ in range(self.num_outputs)
|
||||
]
|
||||
@@ -218,22 +240,31 @@ class TaylorSeerAttentionCacheHook(ModelHook):
|
||||
)
|
||||
return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs)
|
||||
|
||||
should_predict = self.states[0].remaining_predictions > 0
|
||||
if is_warmup_phase or not should_predict:
|
||||
# Full compute during warmup or when refresh needed
|
||||
if self.stop_predicts is not None and self.step_counter >= self.stop_predicts:
|
||||
# After stop_predicts: always full compute without updating state
|
||||
attention_outputs = self.fn_ref.original_forward(*args, **kwargs)
|
||||
if isinstance(attention_outputs, torch.Tensor):
|
||||
attention_outputs = [attention_outputs]
|
||||
else:
|
||||
attention_outputs = list(attention_outputs)
|
||||
is_first_update = self.step_counter == 0 # Only True for the very first step
|
||||
for i, features in enumerate(attention_outputs):
|
||||
self.states[i].update(features, self.step_counter, self.max_order, self.predict_steps, is_first_update)
|
||||
return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs)
|
||||
else:
|
||||
# Predict using Taylor series
|
||||
predicted_outputs = [state.predict(self.step_counter) for state in self.states]
|
||||
return predicted_outputs[0] if self.num_outputs == 1 else tuple(predicted_outputs)
|
||||
should_predict = self.states[0].remaining_predictions > 0
|
||||
if is_warmup_phase or not should_predict:
|
||||
# Full compute during warmup or when refresh needed
|
||||
attention_outputs = self.fn_ref.original_forward(*args, **kwargs)
|
||||
if isinstance(attention_outputs, torch.Tensor):
|
||||
attention_outputs = [attention_outputs]
|
||||
else:
|
||||
attention_outputs = list(attention_outputs)
|
||||
is_first_update = self.step_counter == 0 # Only True for the very first step
|
||||
for i, features in enumerate(attention_outputs):
|
||||
self.states[i].update(features, self.step_counter, self.max_order, self.predict_steps, is_first_update)
|
||||
return attention_outputs[0] if self.num_outputs == 1 else tuple(attention_outputs)
|
||||
else:
|
||||
# Predict using Taylor series
|
||||
predicted_outputs = [state.predict(self.step_counter) for state in self.states]
|
||||
return predicted_outputs[0] if self.num_outputs == 1 else tuple(predicted_outputs)
|
||||
|
||||
def reset_state(self, module: torch.nn.Module) -> None:
|
||||
self.states = None
|
||||
@@ -259,23 +290,23 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfi
|
||||
>>> apply_taylorseer_cache(pipe.transformer, config)
|
||||
```
|
||||
"""
|
||||
if config.skip_compute_identifiers:
|
||||
skip_compute_identifiers = config.skip_compute_identifiers
|
||||
if config.skip_identifiers:
|
||||
skip_identifiers = config.skip_identifiers
|
||||
else:
|
||||
skip_compute_identifiers = SKIP_COMPUTE_IDENTIFIERS.get(config.architecture, [])
|
||||
skip_identifiers = _CACHE_TEMPLATES.get(config.architecture, {}).get("skip", [])
|
||||
|
||||
if config.special_cache_identifiers:
|
||||
special_cache_identifiers = config.special_cache_identifiers
|
||||
if config.cache_identifiers:
|
||||
cache_identifiers = config.cache_identifiers
|
||||
else:
|
||||
special_cache_identifiers = SPECIAL_CACHE_IDENTIFIERS.get(config.architecture, [])
|
||||
cache_identifiers = _CACHE_TEMPLATES.get(config.architecture, {}).get("cache", [])
|
||||
|
||||
logger.debug(f"Skip compute identifiers: {skip_compute_identifiers}")
|
||||
logger.debug(f"Special cache identifiers: {special_cache_identifiers}")
|
||||
logger.debug(f"Skip identifiers: {skip_identifiers}")
|
||||
logger.debug(f"Cache identifiers: {cache_identifiers}")
|
||||
|
||||
for name, submodule in module.named_modules():
|
||||
if (skip_compute_identifiers and special_cache_identifiers) or (special_cache_identifiers):
|
||||
if any(re.fullmatch(identifier, name) for identifier in skip_compute_identifiers) or any(
|
||||
re.fullmatch(identifier, name) for identifier in special_cache_identifiers
|
||||
if (skip_identifiers and cache_identifiers) or (cache_identifiers):
|
||||
if any(re.fullmatch(identifier, name) for identifier in skip_identifiers) or any(
|
||||
re.fullmatch(identifier, name) for identifier in cache_identifiers
|
||||
):
|
||||
logger.debug(f"Applying TaylorSeer cache to {name}")
|
||||
_apply_taylorseer_cache_hook(name, submodule, config)
|
||||
@@ -293,8 +324,8 @@ def _apply_taylorseer_cache_hook(name: str, module: Attention, config: TaylorSee
|
||||
config (TaylorSeerCacheConfig): Configuration for the cache.
|
||||
"""
|
||||
|
||||
is_skip_compute = any(
|
||||
re.fullmatch(identifier, name) for identifier in SKIP_COMPUTE_IDENTIFIERS.get(config.architecture, [])
|
||||
is_skip = any(
|
||||
re.fullmatch(identifier, name) for identifier in _CACHE_TEMPLATES.get(config.architecture, {}).get("skip", [])
|
||||
)
|
||||
|
||||
registry = HookRegistry.check_if_exists_or_initialize(module)
|
||||
@@ -305,7 +336,8 @@ def _apply_taylorseer_cache_hook(name: str, module: Attention, config: TaylorSee
|
||||
config.max_order,
|
||||
config.warmup_steps,
|
||||
config.taylor_factors_dtype,
|
||||
is_skip_compute=is_skip_compute,
|
||||
stop_predicts=config.stop_predicts,
|
||||
is_skip=is_skip,
|
||||
)
|
||||
|
||||
registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK)
|
||||
registry.register_hook(hook, _TAYLORSEER_ATTENTION_CACHE_HOOK)
|
||||
Reference in New Issue
Block a user