1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

rename identifiers, use more expressive taylor predict loop

This commit is contained in:
toilaluan
2025-12-03 10:43:01 +00:00
parent 475ec02d8c
commit 4fb3f53b6c

View File

@@ -51,12 +51,12 @@ class TaylorSeerCacheConfig:
Data type used for storing and computing Taylor series factors. Lower precision reduces memory but may
affect stability; higher precision improves accuracy at the cost of more memory.
inactive_identifiers (`List[str]`, *optional*, defaults to `None`):
Regex patterns (using `re.fullmatch`) for module names to place in "inactive" mode. In this mode, the
skip_predict_identifiers (`List[str]`, *optional*, defaults to `None`):
Regex patterns (using `re.fullmatch`) for module names whose prediction should be skipped. In this mode, the
module computes fully during initial or refresh steps but returns a zero tensor (matching recorded shape)
during prediction steps to skip computation cheaply.
active_identifiers (`List[str]`, *optional*, defaults to `None`):
cache_identifiers (`List[str]`, *optional*, defaults to `None`):
Regex patterns (using `re.fullmatch`) for module names to place in Taylor-series caching mode, where
outputs are approximated and cached for reuse.
@@ -83,8 +83,8 @@ class TaylorSeerCacheConfig:
disable_cache_after_step: Optional[int] = None
max_order: int = 1
taylor_factors_dtype: Optional[torch.dtype] = torch.bfloat16
inactive_identifiers: Optional[List[str]] = None
active_identifiers: Optional[List[str]] = None
skip_predict_identifiers: Optional[List[str]] = None
cache_identifiers: Optional[List[str]] = None
use_lite_mode: bool = False
def __repr__(self) -> str:
@@ -95,8 +95,8 @@ class TaylorSeerCacheConfig:
f"disable_cache_after_step={self.disable_cache_after_step}, "
f"max_order={self.max_order}, "
f"taylor_factors_dtype={self.taylor_factors_dtype}, "
f"inactive_identifiers={self.inactive_identifiers}, "
f"active_identifiers={self.active_identifiers}, "
f"skip_predict_identifiers={self.skip_predict_identifiers}, "
f"cache_identifiers={self.cache_identifiers}, "
f"use_lite_mode={self.use_lite_mode})"
)
@@ -136,7 +136,6 @@ class TaylorSeerState:
if self.is_inactive:
self.inactive_shapes = tuple(output.shape for output in outputs)
else:
self.taylor_factors = {}
for i, features in enumerate(outputs):
new_factors: Dict[int, torch.Tensor] = {0: features}
is_first_update = self.last_update_step is None
@@ -179,17 +178,17 @@ class TaylorSeerState:
else:
if not self.taylor_factors:
raise ValueError("Taylor factors empty during prediction.")
for i in range(len(self.module_dtypes)):
num_outputs = len(self.taylor_factors)
num_orders = len(self.taylor_factors[0])
for i in range(num_outputs):
output_dtype = self.module_dtypes[i]
taylor_factors = self.taylor_factors[i]
# Accumulate Taylor series: f(t0 + Δt) ≈ Σ f^{(n)}(t0) * (Δt^n / n!)
output = torch.zeros_like(taylor_factors[0], dtype=output_dtype)
for order, factor in taylor_factors.items():
# Note: order starts at 0
for order in range(num_orders):
coeff = (step_offset**order) / math.factorial(order)
factor = taylor_factors[order]
output = output + factor.to(output_dtype) * coeff
outputs.append(output)
return outputs
@@ -243,8 +242,8 @@ def _resolve_patterns(config: TaylorSeerCacheConfig) -> Tuple[List[str], List[st
Resolve the effective inactive (`skip_predict_identifiers`) and active (`cache_identifiers`) pattern lists from config + templates.
"""
inactive_patterns = config.inactive_identifiers if config.inactive_identifiers is not None else None
active_patterns = config.active_identifiers if config.active_identifiers is not None else None
inactive_patterns = config.skip_predict_identifiers if config.skip_predict_identifiers is not None else None
active_patterns = config.cache_identifiers if config.cache_identifiers is not None else None
return inactive_patterns or [], active_patterns or []
@@ -288,7 +287,7 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfi
logger.info("Using TaylorSeer Lite variant for cache.")
active_patterns = _PROJ_OUT_IDENTIFIERS
inactive_patterns = _BLOCK_IDENTIFIERS
if config.inactive_identifiers or config.active_identifiers:
if config.skip_predict_identifiers or config.cache_identifiers:
logger.warning("Lite mode overrides user patterns.")
for name, submodule in module.named_modules():