
chores: naming, remove redundancy

Author: toilaluan
Date: 2025-11-28 07:23:01 +00:00
parent 656c7bc501
commit 24267c76de


@@ -29,64 +29,74 @@ class TaylorSeerCacheConfig:
Configuration for TaylorSeer cache. See: https://huggingface.co/papers/2503.06923
Attributes:
warmup_steps (`int`, defaults to `3`):
Number of denoising steps to run with full computation before enabling caching. During warmup, the Taylor
series factors are still updated, but no predictions are used.
cache_interval (`int`, defaults to `5`):
The interval between full-computation steps. After a full computation, cached predictions are used for the
intervening denoising steps until the next full forward pass refreshes the state.
predict_steps (`int`, defaults to `5`):
Number of prediction (cached) steps to take between two full computations. That is, once a module state is
refreshed, it will be reused for `predict_steps` subsequent denoising steps, then a new full forward will
be computed on the next step.
disable_cache_before_step (`int`, defaults to `3`):
The denoising step index before which caching is disabled, meaning full computation is performed for the initial
steps (0 to disable_cache_before_step - 1) to gather data for Taylor series approximations. During these steps,
Taylor factors are updated, but caching/predictions are not applied. Caching begins at this step.
stop_predicts (`int`, *optional*, defaults to `None`):
Denoising step index at which caching is disabled. If provided, once `self.current_step >= stop_predicts`,
all modules are evaluated normally (no predictions, no state updates).
disable_cache_after_step (`int`, *optional*, defaults to `None`):
The denoising step index after which caching is disabled. If set, for steps >= this value, all modules run full
computations without predictions or state updates, ensuring accuracy in later stages if needed.
max_order (`int`, defaults to `1`):
Maximum order of Taylor series expansion to approximate the features. Higher order gives closer
approximation but more compute.
The highest order in the Taylor series expansion for approximating module outputs. Higher orders provide better
approximations but increase computation and memory usage.
taylor_factors_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
Data type for computing Taylor series expansion factors. Use lower precision to reduce memory usage. Use
higher precision to improve numerical stability.
Data type used for storing and computing Taylor series factors. Lower precision reduces memory but may affect
stability; higher precision improves accuracy at the cost of more memory.
skip_identifiers (`List[str]`, *optional*, defaults to `None`):
Regex patterns (fullmatch) for module names to be placed in "skip" mode, where the module runs fully
during warmup / refresh steps but its output is replaced by a cheap dummy tensor during prediction steps.
inactive_identifiers (`List[str]`, *optional*, defaults to `None`):
Regex patterns (using `re.fullmatch`) for module names to place in "inactive" mode. In this mode, the module
computes fully during initial or refresh steps but returns a zero tensor (matching recorded shape) during
prediction steps to skip computation cheaply.
cache_identifiers (`List[str]`, *optional*, defaults to `None`):
Regex patterns (fullmatch) for module names to be placed in Taylor-series caching mode.
active_identifiers (`List[str]`, *optional*, defaults to `None`):
Regex patterns (using `re.fullmatch`) for module names to place in Taylor-series caching mode, where outputs
are approximated and cached for reuse.
use_lite_mode (`bool`, *optional*, defaults to `False`):
Enables a lightweight TaylorSeer variant that minimizes memory usage by applying predefined patterns for
skipping and caching (e.g., skipping blocks and caching projections). This overrides any custom
`inactive_identifiers` or `active_identifiers`.
lite (`bool`, *optional*, defaults to `False`):
Whether to use a TaylorSeer Lite variant that reduces memory usage. This option overrides any user-provided
`skip_identifiers` or `cache_identifiers` patterns.
Notes:
- Patterns are applied with `re.fullmatch` on `module_name`.
- If either `skip_identifiers` or `cache_identifiers` is provided, only modules matching at least one of those
patterns will be hooked.
- If neither is provided, all attention-like modules will be hooked.
- Patterns are matched using `re.fullmatch` on the module name.
- If `inactive_identifiers` or `active_identifiers` are provided, only matching modules are hooked.
- If neither is provided, all attention-like modules are hooked by default.
- Example of inactive and active usage:
```
def forward(x):
x = self.module1(x) # inactive module: returns a zero tensor with the shape recorded during full compute
x = self.module2(x) # active module: during prediction steps, its output is approximated from cached Taylor factors
return x
```
"""
warmup_steps: int = 3
predict_steps: int = 5
stop_predicts: Optional[int] = None
cache_interval: int = 5
disable_cache_before_step: int = 3
disable_cache_after_step: Optional[int] = None
max_order: int = 1
taylor_factors_dtype: Optional[torch.dtype] = torch.bfloat16
skip_identifiers: Optional[List[str]] = None
cache_identifiers: Optional[List[str]] = None
lite: bool = False
inactive_identifiers: Optional[List[str]] = None
active_identifiers: Optional[List[str]] = None
use_lite_mode: bool = False
def __repr__(self) -> str:
return (
"TaylorSeerCacheConfig("
f"warmup_steps={self.warmup_steps}, "
f"predict_steps={self.predict_steps}, "
f"stop_predicts={self.stop_predicts}, "
f"cache_interval={self.cache_interval}, "
f"disable_cache_before_step={self.disable_cache_before_step}, "
f"disable_cache_after_step={self.disable_cache_after_step}, "
f"max_order={self.max_order}, "
f"taylor_factors_dtype={self.taylor_factors_dtype}, "
f"skip_identifiers={self.skip_identifiers}, "
f"cache_identifiers={self.cache_identifiers}, "
f"lite={self.lite})"
f"inactive_identifiers={self.inactive_identifiers}, "
f"active_identifiers={self.active_identifiers}, "
f"use_lite_mode={self.use_lite_mode})"
)
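To see the renamed fields together, here is a minimal construction sketch; the values are illustrative rather than tuned recommendations, and the import location is assumed:

```python
import torch

# Assumed import path; adjust to wherever TaylorSeerCacheConfig is exposed.
from diffusers.hooks import TaylorSeerCacheConfig

config = TaylorSeerCacheConfig(
    disable_cache_before_step=3,   # full compute for steps 0-2
    cache_interval=5,              # refresh the Taylor factors every 5 steps
    disable_cache_after_step=28,   # full compute again for the final steps
    max_order=1,                   # first-order (linear) extrapolation
    taylor_factors_dtype=torch.float32,
)
print(config)  # __repr__ lists the resolved settings
```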
@@ -95,70 +105,88 @@ class TaylorSeerState:
self,
taylor_factors_dtype: Optional[torch.dtype] = torch.bfloat16,
max_order: int = 1,
is_inactive: bool = False,
):
self.taylor_factors_dtype = taylor_factors_dtype
self.max_order = max_order
self.is_inactive = is_inactive
self.module_dtypes: Tuple[torch.dtype, ...] = ()
self.last_update_step: Optional[int] = None
self.taylor_factors: Dict[int, Dict[int, torch.Tensor]] = {}
# For skip-mode modules
self.inactive_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None
self.device: Optional[torch.device] = None
self.dummy_tensors: Optional[Tuple[torch.Tensor, ...]] = None
self.current_step: int = -1
self.current_step = -1
def reset(self) -> None:
self.current_step = -1
self.last_update_step = None
self.taylor_factors = {}
self.inactive_shapes = None
self.device = None
self.dummy_tensors = None
self.current_step = -1
def update(
self,
outputs: Tuple[torch.Tensor, ...],
current_step: int,
) -> None:
self.module_dtypes = tuple(output.dtype for output in outputs)
for i in range(len(outputs)):
features = outputs[i].to(self.taylor_factors_dtype)
new_factors: Dict[int, torch.Tensor] = {0: features}
is_first_update = self.last_update_step is None
if not is_first_update:
delta_step = current_step - self.last_update_step
if delta_step == 0:
raise ValueError("Delta step cannot be zero for TaylorSeer update.")
self.device = outputs[0].device
# Recursive divided differences up to max_order
for j in range(self.max_order):
prev = self.taylor_factors[i].get(j)
if prev is None:
break
new_factors[j + 1] = (new_factors[j] - prev.to(self.taylor_factors_dtype)) / delta_step
self.taylor_factors[i] = new_factors
self.last_update_step = current_step
if self.is_inactive:
self.inactive_shapes = tuple(output.shape for output in outputs)
else:
self.taylor_factors = {}
for i, output in enumerate(outputs):
features = output.to(self.taylor_factors_dtype)
new_factors: Dict[int, torch.Tensor] = {0: features}
is_first_update = self.last_update_step is None
if not is_first_update:
delta_step = self.current_step - self.last_update_step
if delta_step == 0:
raise ValueError("Delta step cannot be zero for TaylorSeer update.")
def predict(self, current_step: int) -> torch.Tensor:
# Recursive divided differences up to max_order
prev_factors = self.taylor_factors.get(i, {})
for j in range(self.max_order):
prev = prev_factors.get(j)
if prev is None:
break
new_factors[j + 1] = (new_factors[j] - prev.to(self.taylor_factors_dtype)) / delta_step
self.taylor_factors[i] = new_factors
self.last_update_step = self.current_step
def predict(self) -> List[torch.Tensor]:
if self.last_update_step is None:
raise ValueError("Cannot predict without prior initialization/update.")
step_offset = current_step - self.last_update_step
if not self.taylor_factors:
raise ValueError("Taylor factors empty during prediction.")
step_offset = self.current_step - self.last_update_step
outputs = []
for i in range(len(self.module_dtypes)):
taylor_factors = self.taylor_factors[i]
# Accumulate Taylor series: f(t0 + Δt) ≈ Σ f^{(n)}(t0) * (Δt^n / n!)
output = torch.zeros_like(taylor_factors[0])
for order, factor in taylor_factors.items():
# Note: order starts at 0
coeff = (step_offset**order) / math.factorial(order)
output = output + factor * coeff
outputs.append(output.to(self.module_dtypes[i]))
if self.is_inactive:
if self.inactive_shapes is None:
raise ValueError("Inactive shapes not set during prediction.")
for i in range(len(self.module_dtypes)):
outputs.append(
torch.zeros(
self.inactive_shapes[i],
dtype=self.module_dtypes[i],
device=self.device,
)
)
else:
if not self.taylor_factors:
raise ValueError("Taylor factors empty during prediction.")
for i in range(len(self.module_dtypes)):
taylor_factors = self.taylor_factors[i]
# Accumulate Taylor series: f(t0 + Δt) ≈ Σ f^{(n)}(t0) * (Δt^n / n!)
output = torch.zeros_like(taylor_factors[0])
for order, factor in taylor_factors.items():
# Note: order starts at 0
coeff = (step_offset**order) / math.factorial(order)
output = output + factor * coeff
outputs.append(output.to(self.module_dtypes[i]))
return outputs
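For readers outside the diff context, a self-contained numeric sketch of the same divided-difference update and Taylor extrapolation, reduced to a single tensor and `max_order=1` (names are illustrative):

```python
import math

import torch

factors: dict[int, torch.Tensor] = {}  # order -> factor, like taylor_factors[i]
last_update_step = None

def update(features: torch.Tensor, step: int) -> None:
    global factors, last_update_step
    new_factors = {0: features}
    if last_update_step is not None and 0 in factors:
        delta = step - last_update_step
        # First-order divided difference approximates d(features)/d(step).
        new_factors[1] = (new_factors[0] - factors[0]) / delta
    factors = new_factors
    last_update_step = step

def predict(step: int) -> torch.Tensor:
    # f(t0 + dt) ~= sum_n f^(n)(t0) * dt^n / n!
    dt = step - last_update_step
    out = torch.zeros_like(factors[0])
    for order, factor in factors.items():
        out = out + factor * (dt**order) / math.factorial(order)
    return out

update(torch.tensor([1.0]), step=0)
update(torch.tensor([3.0]), step=2)  # slope (3 - 1) / 2 = 1.0
print(predict(step=3))               # tensor([4.])  (3.0 + 1.0 * 1)
```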
@@ -168,24 +196,18 @@ class TaylorSeerCacheHook(ModelHook):
def __init__(
self,
module_name: str,
predict_steps: int,
warmup_steps: int,
cache_interval: int,
disable_cache_before_step: int,
taylor_factors_dtype: torch.dtype,
state_manager: StateManager,
stop_predicts: Optional[int] = None,
is_skip: bool = False,
disable_cache_after_step: Optional[int] = None,
):
super().__init__()
self.module_name = module_name
self.predict_steps = predict_steps
self.warmup_steps = warmup_steps
self.stop_predicts = stop_predicts
self.cache_interval = cache_interval
self.disable_cache_before_step = disable_cache_before_step
self.disable_cache_after_step = disable_cache_after_step
self.taylor_factors_dtype = taylor_factors_dtype
self.state_manager = state_manager
self.is_skip = is_skip
self.dummy_outputs = None
def initialize_hook(self, module: torch.nn.Module):
return module
@@ -194,50 +216,48 @@ class TaylorSeerCacheHook(ModelHook):
"""
Reset state between sampling runs.
"""
self.dummy_outputs = None
self.current_step = -1
self.state_manager.reset()
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
state: TaylorSeerState = self.state_manager.get_state()
state.current_step += 1
current_step = state.current_step
is_warmup_phase = current_step < self.warmup_steps
is_warmup_phase = current_step < self.disable_cache_before_step
is_compute_interval = ((current_step - self.disable_cache_before_step - 1) % self.cache_interval == 0)
is_cooldown_phase = self.disable_cache_after_step is not None and current_step >= self.disable_cache_after_step
should_compute = (
is_warmup_phase
or ((current_step - self.warmup_steps - 1) % self.predict_steps == 0)
or (self.stop_predicts is not None and current_step >= self.stop_predicts)
or is_compute_interval
or is_cooldown_phase
)
if should_compute:
outputs = self.fn_ref.original_forward(*args, **kwargs)
if not self.is_skip:
state.update((outputs,) if isinstance(outputs, torch.Tensor) else outputs, current_step)
else:
self.dummy_outputs = outputs
wrapped_outputs = (outputs,) if isinstance(outputs, torch.Tensor) else outputs
state.update(wrapped_outputs)
return outputs
if self.is_skip:
return self.dummy_outputs
outputs = state.predict(current_step)
return outputs[0] if len(outputs) == 1 else outputs
outputs_list = state.predict()
return outputs_list[0] if len(outputs_list) == 1 else tuple(outputs_list)
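The interplay of these flags is easiest to see by tracing the schedule directly; the constants below are assumptions for illustration (note that with these values Python's negative modulo already makes step 3 a prediction step):

```python
# Trace which denoising steps run a full forward pass vs. a cached prediction,
# mirroring the should_compute logic above (cooldown value assumed).
disable_cache_before_step = 3
cache_interval = 5
disable_cache_after_step = 28

for step in range(30):
    is_warmup = step < disable_cache_before_step
    is_interval = (step - disable_cache_before_step - 1) % cache_interval == 0
    is_cooldown = step >= disable_cache_after_step
    mode = "compute" if (is_warmup or is_interval or is_cooldown) else "predict"
    print(f"{step:2d} {mode}")
# compute: 0, 1, 2 (warmup), 4, 9, 14, 19, 24 (refresh), 28, 29 (cooldown)
```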
def _resolve_patterns(config: TaylorSeerCacheConfig) -> Tuple[List[str], List[str]]:
"""
Resolve effective skip and cache pattern lists from config + templates.
Resolve effective inactive and active pattern lists from config + templates.
"""
skip_patterns = config.skip_identifiers if config.skip_identifiers is not None else None
cache_patterns = config.cache_identifiers if config.cache_identifiers is not None else None
inactive_patterns = config.inactive_identifiers if config.inactive_identifiers is not None else None
active_patterns = config.active_identifiers if config.active_identifiers is not None else None
return skip_patterns or [], cache_patterns or []
return inactive_patterns or [], active_patterns or []
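Since matching uses `re.fullmatch`, a pattern must cover the entire module name; a quick sketch with hypothetical names:

```python
import re

# Hypothetical patterns and module names to illustrate fullmatch semantics.
patterns = [r"transformer_blocks\.\d+\.attn1"]
names = ["transformer_blocks.0.attn1", "transformer_blocks.0.attn1.to_q"]
for name in names:
    hit = any(re.fullmatch(p, name) for p in patterns)
    print(name, hit)
# The child module ...to_q does not match: fullmatch rejects partial prefixes.
```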
def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfig):
"""
Applies the TaylorSeer cache to a given module (typically a pipeline's transformer or UNet).
This function hooks selected modules in the model to enable caching or skipping based on the provided configuration,
reducing redundant computations in diffusion denoising loops.
Args:
module (torch.nn.Module): The model subtree to apply the hooks to.
config (TaylorSeerCacheConfig): Configuration for the cache.
@@ -254,60 +274,41 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfi
>>> pipe.to("cuda")
>>> config = TaylorSeerCacheConfig(
... predict_steps=5,
... cache_interval=5,
... max_order=1,
... warmup_steps=3,
... disable_cache_before_step=3,
... taylor_factors_dtype=torch.float32,
... )
>>> pipe.transformer.enable_cache(config)
```
"""
skip_patterns, cache_patterns = _resolve_patterns(config)
inactive_patterns, active_patterns = _resolve_patterns(config)
logger.debug("TaylorSeer skip identifiers: %s", skip_patterns)
logger.debug("TaylorSeer cache identifiers: %s", cache_patterns)
active_patterns = active_patterns or _TRANSFORMER_BLOCK_IDENTIFIERS
cache_patterns = cache_patterns or _TRANSFORMER_BLOCK_IDENTIFIERS
if config.lite:
if config.use_lite_mode:
logger.info("Using TaylorSeer Lite variant for cache.")
cache_patterns = _PROJ_OUT_IDENTIFIERS
skip_patterns = _BLOCK_IDENTIFIERS
if config.skip_identifiers or config.cache_identifiers:
active_patterns = _PROJ_OUT_IDENTIFIERS
inactive_patterns = _BLOCK_IDENTIFIERS
if config.inactive_identifiers or config.active_identifiers:
logger.warning("Lite mode overrides user patterns.")
for name, submodule in module.named_modules():
matches_skip = any(re.fullmatch(pattern, name) for pattern in skip_patterns)
matches_cache = any(re.fullmatch(pattern, name) for pattern in cache_patterns)
if not (matches_skip or matches_cache):
matches_inactive = any(re.fullmatch(pattern, name) for pattern in inactive_patterns)
matches_active = any(re.fullmatch(pattern, name) for pattern in active_patterns)
if not (matches_inactive or matches_active):
continue
logger.debug(
"Applying TaylorSeer cache to %s (mode=%s)",
name,
"skip" if matches_skip else "cache",
)
state_manager = StateManager(
TaylorSeerState,
init_kwargs={
"taylor_factors_dtype": config.taylor_factors_dtype,
"max_order": config.max_order,
},
)
_apply_taylorseer_cache_hook(
name=name,
module=submodule,
config=config,
is_skip=matches_skip,
state_manager=state_manager,
is_inactive=matches_inactive,
)
def _apply_taylorseer_cache_hook(
name: str,
module: nn.Module,
config: TaylorSeerCacheConfig,
is_skip: bool,
state_manager: StateManager,
is_inactive: bool,
):
"""
Registers the TaylorSeer hook on the specified nn.Module.
@@ -316,19 +317,25 @@ def _apply_taylorseer_cache_hook(
name: Name of the module.
module: The nn.Module to be hooked.
config: Cache configuration.
is_skip: Whether this module should operate in "skip" mode.
state_manager: The state manager for managing hook state.
is_inactive: Whether this module should operate in "inactive" mode.
"""
state_manager = StateManager(
TaylorSeerState,
init_kwargs={
"taylor_factors_dtype": config.taylor_factors_dtype,
"max_order": config.max_order,
"is_inactive": is_inactive,
},
)
registry = HookRegistry.check_if_exists_or_initialize(module)
hook = TaylorSeerCacheHook(
module_name=name,
predict_steps=config.predict_steps,
warmup_steps=config.warmup_steps,
cache_interval=config.cache_interval,
disable_cache_before_step=config.disable_cache_before_step,
taylor_factors_dtype=config.taylor_factors_dtype,
stop_predicts=config.stop_predicts,
is_skip=is_skip,
disable_cache_after_step=config.disable_cache_after_step,
state_manager=state_manager,
)
registry.register_hook(hook, _TAYLORSEER_CACHE_HOOK)
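For completeness, a hedged end-to-end sketch of the lite variant; the pipeline class, checkpoint, and import path are placeholders modeled on the docstring example above, not verified API surface:

```python
import torch
from diffusers import FluxPipeline
from diffusers.hooks import TaylorSeerCacheConfig  # assumed import location

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

# use_lite_mode applies the predefined block/projection patterns and
# overrides any custom inactive/active identifiers.
config = TaylorSeerCacheConfig(use_lite_mode=True)
pipe.transformer.enable_cache(config)

image = pipe("a photo of a cat", num_inference_steps=30).images[0]
```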