Mirror of https://github.com/huggingface/diffusers.git
chores: naming, remove redundancy
@@ -29,64 +29,74 @@ class TaylorSeerCacheConfig:
    Configuration for TaylorSeer cache. See: https://huggingface.co/papers/2503.06923

    Attributes:
        warmup_steps (`int`, defaults to `3`):
            Number of denoising steps to run with full computation before enabling caching. During warmup, the Taylor
            series factors are still updated, but no predictions are used.
        cache_interval (`int`, defaults to `5`):
            The interval between full computation steps. After a full computation, the cached (predicted) outputs are reused
            for this many subsequent denoising steps before refreshing with a new full forward pass.

        predict_steps (`int`, defaults to `5`):
            Number of prediction (cached) steps to take between two full computations. That is, once a module state is
            refreshed, it will be reused for `predict_steps` subsequent denoising steps, then a new full forward will
            be computed on the next step.
        disable_cache_before_step (`int`, defaults to `3`):
            The denoising step index before which caching is disabled, meaning full computation is performed for the initial
            steps (0 to disable_cache_before_step - 1) to gather data for Taylor series approximations. During these steps,
            Taylor factors are updated, but caching/predictions are not applied. Caching begins at this step.

        stop_predicts (`int`, *optional*, defaults to `None`):
            Denoising step index at which caching is disabled. If provided, for `self.current_step >= stop_predicts`
            all modules are evaluated normally (no predictions, no state updates).
        disable_cache_after_step (`int`, *optional*, defaults to `None`):
            The denoising step index after which caching is disabled. If set, for steps >= this value, all modules run full
            computations without predictions or state updates, ensuring accuracy in later stages if needed.

        max_order (`int`, defaults to `1`):
            Maximum order of Taylor series expansion to approximate the features. Higher order gives closer
            approximation but more compute.
            The highest order in the Taylor series expansion for approximating module outputs. Higher orders provide better
            approximations but increase computation and memory usage.

        taylor_factors_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
            Data type for computing Taylor series expansion factors. Use lower precision to reduce memory usage. Use
            higher precision to improve numerical stability.
            Data type used for storing and computing Taylor series factors. Lower precision reduces memory but may affect
            stability; higher precision improves accuracy at the cost of more memory.

        skip_identifiers (`List[str]`, *optional*, defaults to `None`):
            Regex patterns (fullmatch) for module names to be placed in "skip" mode, where the module is evaluated
            during warmup / refresh, but then replaced by a cheap dummy tensor during prediction steps.
        inactive_identifiers (`List[str]`, *optional*, defaults to `None`):
            Regex patterns (using `re.fullmatch`) for module names to place in "inactive" mode. In this mode, the module
            computes fully during initial or refresh steps but returns a zero tensor (matching recorded shape) during
            prediction steps to skip computation cheaply.

        cache_identifiers (`List[str]`, *optional*, defaults to `None`):
            Regex patterns (fullmatch) for module names to be placed in Taylor-series caching mode.
        active_identifiers (`List[str]`, *optional*, defaults to `None`):
            Regex patterns (using `re.fullmatch`) for module names to place in Taylor-series caching mode, where outputs
            are approximated and cached for reuse.

        use_lite_mode (`bool`, *optional*, defaults to `False`):
            Enables a lightweight TaylorSeer variant that minimizes memory usage by applying predefined patterns for
            skipping and caching (e.g., skipping blocks and caching projections). This overrides any custom
            `inactive_identifiers` or `active_identifiers`.

        lite (`bool`, *optional*, defaults to `False`):
            Whether to use a TaylorSeer Lite variant that reduces memory usage. This option overrides any user-provided
            `skip_identifiers` or `cache_identifiers` patterns.
    Notes:
        - Patterns are applied with `re.fullmatch` on `module_name`.
        - If either `skip_identifiers` or `cache_identifiers` is provided, only modules matching at least one of those
          patterns will be hooked.
        - If neither is provided, all attention-like modules will be hooked.
        - Patterns are matched using `re.fullmatch` on the module name.
        - If `inactive_identifiers` or `active_identifiers` are provided, only matching modules are hooked.
        - If neither is provided, all attention-like modules are hooked by default.
        - Example of inactive and active usage:
        ```
        def forward(x):
            x = self.module1(x)  # inactive module: returns zeros tensor based on shape recorded during full compute
            x = self.module2(x)  # active module: caches output here, avoiding recomputation of prior steps
            return x
        ```
    """

    warmup_steps: int = 3
    predict_steps: int = 5
    stop_predicts: Optional[int] = None
    cache_interval: int = 5
    disable_cache_before_step: int = 3
    disable_cache_after_step: Optional[int] = None
    max_order: int = 1
    taylor_factors_dtype: Optional[torch.dtype] = torch.bfloat16
    skip_identifiers: Optional[List[str]] = None
    cache_identifiers: Optional[List[str]] = None
    lite: bool = False
    inactive_identifiers: Optional[List[str]] = None
    active_identifiers: Optional[List[str]] = None
    use_lite_mode: bool = False

    def __repr__(self) -> str:
        return (
            "TaylorSeerCacheConfig("
            f"warmup_steps={self.warmup_steps}, "
            f"predict_steps={self.predict_steps}, "
            f"stop_predicts={self.stop_predicts}, "
            f"cache_interval={self.cache_interval}, "
            f"disable_cache_before_step={self.disable_cache_before_step}, "
            f"disable_cache_after_step={self.disable_cache_after_step}, "
            f"max_order={self.max_order}, "
            f"taylor_factors_dtype={self.taylor_factors_dtype}, "
            f"skip_identifiers={self.skip_identifiers}, "
            f"cache_identifiers={self.cache_identifiers}, "
            f"lite={self.lite})"
            f"inactive_identifiers={self.inactive_identifiers}, "
            f"active_identifiers={self.active_identifiers}, "
            f"use_lite_mode={self.use_lite_mode})"
        )
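Read together, the scheduling fields above mean: full forward passes during the initial steps, a periodic refresh every `cache_interval` steps, and (optionally) full computation again from `disable_cache_after_step` onward. A minimal sketch of that schedule, assuming the refresh test `(step - disable_cache_before_step - 1) % cache_interval == 0` that appears in the hook further down; the helper name and printed schedule are illustrative, not part of the diffusers API:

```python
# Sketch only: reproduces the scheduling arithmetic from the hook's `new_forward`
# as read from this diff; it is not part of the diffusers API.
from typing import Optional


def is_full_compute_step(
    step: int,
    disable_cache_before_step: int = 3,
    cache_interval: int = 5,
    disable_cache_after_step: Optional[int] = None,
) -> bool:
    warmup = step < disable_cache_before_step  # early steps: always compute
    refresh = (step - disable_cache_before_step - 1) % cache_interval == 0  # periodic refresh
    cooldown = disable_cache_after_step is not None and step >= disable_cache_after_step
    return warmup or refresh or cooldown


if __name__ == "__main__":
    schedule = ["compute" if is_full_compute_step(s) else "predict" for s in range(12)]
    print(schedule)
    # With the defaults above: steps 0-2 compute (warmup), steps 4 and 9 refresh,
    # and the remaining steps reuse Taylor predictions.
```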
@@ -95,70 +105,88 @@ class TaylorSeerState:
        self,
        taylor_factors_dtype: Optional[torch.dtype] = torch.bfloat16,
        max_order: int = 1,
        is_inactive: bool = False,
    ):
        self.taylor_factors_dtype = taylor_factors_dtype
        self.max_order = max_order
        self.is_inactive = is_inactive

        self.module_dtypes: Tuple[torch.dtype, ...] = ()
        self.last_update_step: Optional[int] = None
        self.taylor_factors: Dict[int, Dict[int, torch.Tensor]] = {}

        # For skip-mode modules
        self.inactive_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None
        self.device: Optional[torch.device] = None
        self.dummy_tensors: Optional[Tuple[torch.Tensor, ...]] = None
        self.current_step: int = -1

        self.current_step = -1

    def reset(self) -> None:
        self.current_step = -1
        self.last_update_step = None
        self.taylor_factors = {}
        self.inactive_shapes = None
        self.device = None
        self.dummy_tensors = None
        self.current_step = -1

    def update(
        self,
        outputs: Tuple[torch.Tensor, ...],
        current_step: int,
    ) -> None:
        self.module_dtypes = tuple(output.dtype for output in outputs)
        for i in range(len(outputs)):
            features = outputs[i].to(self.taylor_factors_dtype)
            new_factors: Dict[int, torch.Tensor] = {0: features}
            is_first_update = self.last_update_step is None
            if not is_first_update:
                delta_step = current_step - self.last_update_step
                if delta_step == 0:
                    raise ValueError("Delta step cannot be zero for TaylorSeer update.")
        self.device = outputs[0].device

                # Recursive divided differences up to max_order
                for j in range(self.max_order):
                    prev = self.taylor_factors[i].get(j)
                    if prev is None:
                        break
                    new_factors[j + 1] = (new_factors[j] - prev.to(self.taylor_factors_dtype)) / delta_step
            self.taylor_factors[i] = new_factors
        self.last_update_step = current_step
        if self.is_inactive:
            self.inactive_shapes = tuple(output.shape for output in outputs)
        else:
            self.taylor_factors = {}
            for i, output in enumerate(outputs):
                features = output.to(self.taylor_factors_dtype)
                new_factors: Dict[int, torch.Tensor] = {0: features}
                is_first_update = self.last_update_step is None
                if not is_first_update:
                    delta_step = self.current_step - self.last_update_step
                    if delta_step == 0:
                        raise ValueError("Delta step cannot be zero for TaylorSeer update.")

    def predict(self, current_step: int) -> torch.Tensor:
                    # Recursive divided differences up to max_order
                    prev_factors = self.taylor_factors.get(i, {})
                    for j in range(self.max_order):
                        prev = prev_factors.get(j)
                        if prev is None:
                            break
                        new_factors[j + 1] = (new_factors[j] - prev.to(self.taylor_factors_dtype)) / delta_step
                self.taylor_factors[i] = new_factors

        self.last_update_step = self.current_step

    def predict(self) -> List[torch.Tensor]:
        if self.last_update_step is None:
            raise ValueError("Cannot predict without prior initialization/update.")

        step_offset = current_step - self.last_update_step

        if not self.taylor_factors:
            raise ValueError("Taylor factors empty during prediction.")
        step_offset = self.current_step - self.last_update_step

        outputs = []
        for i in range(len(self.module_dtypes)):
            taylor_factors = self.taylor_factors[i]
            # Accumulate Taylor series: f(t0 + Δt) ≈ Σ f^{(n)}(t0) * (Δt^n / n!)
            output = torch.zeros_like(taylor_factors[0])
            for order, factor in taylor_factors.items():
                # Note: order starts at 0
                coeff = (step_offset**order) / math.factorial(order)
                output = output + factor * coeff
            outputs.append(output.to(self.module_dtypes[i]))
        if self.is_inactive:
            if self.inactive_shapes is None:
                raise ValueError("Inactive shapes not set during prediction.")
            for i in range(len(self.module_dtypes)):
                outputs.append(
                    torch.zeros(
                        self.inactive_shapes[i],
                        dtype=self.module_dtypes[i],
                        device=self.device,
                    )
                )
        else:
            if not self.taylor_factors:
                raise ValueError("Taylor factors empty during prediction.")
            for i in range(len(self.module_dtypes)):
                taylor_factors = self.taylor_factors[i]
                # Accumulate Taylor series: f(t0 + Δt) ≈ Σ f^{(n)}(t0) * (Δt^n / n!)
                output = torch.zeros_like(taylor_factors[0])
                for order, factor in taylor_factors.items():
                    # Note: order starts at 0
                    coeff = (step_offset**order) / math.factorial(order)
                    output = output + factor * coeff
                outputs.append(output.to(self.module_dtypes[i]))

        return outputs
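The update/predict pair above maintains finite-difference `taylor_factors` (order 0 is the raw features, order n+1 is a divided difference of order n) and extrapolates them with the series in the comment, f(t0 + Δt) ≈ Σ f^(n)(t0) · Δt^n / n!. A stripped-down sketch of that math for a single tensor, assuming a single output and ignoring the dtype/device bookkeeping; `update_factors` and `predict` here are illustrative helpers, not the class above:

```python
# Simplified sketch of the Taylor-factor math: divided differences are refreshed
# on full-compute steps and extrapolated on prediction steps.
import math
from typing import Dict

import torch


def update_factors(
    factors: Dict[int, torch.Tensor], features: torch.Tensor, delta_step: int, max_order: int = 1
) -> Dict[int, torch.Tensor]:
    new_factors = {0: features}
    for j in range(max_order):
        if j not in factors:
            break
        # Recursive divided difference, mirroring the loop in `update`.
        new_factors[j + 1] = (new_factors[j] - factors[j]) / delta_step
    return new_factors


def predict(factors: Dict[int, torch.Tensor], step_offset: int) -> torch.Tensor:
    # f(t0 + Δt) ≈ Σ_n f^{(n)}(t0) * Δt^n / n!
    out = torch.zeros_like(factors[0])
    for order, factor in factors.items():
        out = out + factor * (step_offset**order) / math.factorial(order)
    return out


# Toy usage: the feature evolves linearly over steps, so a first-order expansion is exact.
f0, f5 = torch.zeros(4), torch.full((4,), 5.0)
factors = update_factors({}, f0, delta_step=1)   # first full compute at step 0
factors = update_factors(factors, f5, delta_step=5)  # refresh at step 5
print(predict(factors, step_offset=2))           # ≈ value expected at step 7
```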
@@ -168,24 +196,18 @@ class TaylorSeerCacheHook(ModelHook):

    def __init__(
        self,
        module_name: str,
        predict_steps: int,
        warmup_steps: int,
        cache_interval: int,
        disable_cache_before_step: int,
        taylor_factors_dtype: torch.dtype,
        state_manager: StateManager,
        stop_predicts: Optional[int] = None,
        is_skip: bool = False,
        disable_cache_after_step: Optional[int] = None,
    ):
        super().__init__()
        self.module_name = module_name
        self.predict_steps = predict_steps
        self.warmup_steps = warmup_steps
        self.stop_predicts = stop_predicts
        self.cache_interval = cache_interval
        self.disable_cache_before_step = disable_cache_before_step
        self.disable_cache_after_step = disable_cache_after_step
        self.taylor_factors_dtype = taylor_factors_dtype
        self.state_manager = state_manager
        self.is_skip = is_skip

        self.dummy_outputs = None

    def initialize_hook(self, module: torch.nn.Module):
        return module
@@ -194,50 +216,48 @@ class TaylorSeerCacheHook(ModelHook):
        """
        Reset state between sampling runs.
        """
        self.dummy_outputs = None
        self.current_step = -1
        self.state_manager.reset()

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        state: TaylorSeerState = self.state_manager.get_state()
        state.current_step += 1
        current_step = state.current_step
        is_warmup_phase = current_step < self.warmup_steps
        is_warmup_phase = current_step < self.disable_cache_before_step
        is_compute_interval = ((current_step - self.disable_cache_before_step - 1) % self.cache_interval == 0)
        is_cooldown_phase = self.disable_cache_after_step is not None and current_step >= self.disable_cache_after_step
        should_compute = (
            is_warmup_phase
            or ((current_step - self.warmup_steps - 1) % self.predict_steps == 0)
            or (self.stop_predicts is not None and current_step >= self.stop_predicts)
            or is_compute_interval
            or is_cooldown_phase
        )
        if should_compute:
            outputs = self.fn_ref.original_forward(*args, **kwargs)
            if not self.is_skip:
                state.update((outputs,) if isinstance(outputs, torch.Tensor) else outputs, current_step)
            else:
                self.dummy_outputs = outputs
            wrapped_outputs = (outputs,) if isinstance(outputs, torch.Tensor) else outputs
            state.update(wrapped_outputs)
            return outputs

        if self.is_skip:
            return self.dummy_outputs

        outputs = state.predict(current_step)
        return outputs[0] if len(outputs) == 1 else outputs
        outputs_list = state.predict()
        return outputs_list[0] if len(outputs_list) == 1 else tuple(outputs_list)


def _resolve_patterns(config: TaylorSeerCacheConfig) -> Tuple[List[str], List[str]]:
    """
    Resolve effective skip and cache pattern lists from config + templates.
    Resolve effective inactive and active pattern lists from config + templates.
    """

    skip_patterns = config.skip_identifiers if config.skip_identifiers is not None else None
    cache_patterns = config.cache_identifiers if config.cache_identifiers is not None else None
    inactive_patterns = config.inactive_identifiers if config.inactive_identifiers is not None else None
    active_patterns = config.active_identifiers if config.active_identifiers is not None else None

    return skip_patterns or [], cache_patterns or []
    return inactive_patterns or [], active_patterns or []


def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfig):
    """
    Applies the TaylorSeer cache to a given pipeline (typically the transformer / UNet).

    This function hooks selected modules in the model to enable caching or skipping based on the provided configuration,
    reducing redundant computations in diffusion denoising loops.

    Args:
        module (torch.nn.Module): The model subtree to apply the hooks to.
        config (TaylorSeerCacheConfig): Configuration for the cache.
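`_resolve_patterns` only forwards whatever identifier lists the config carries; the actual selection happens in `apply_taylorseer_cache` with `re.fullmatch` against fully qualified module names. A small sketch of that matching, using hypothetical pattern strings and module names (the real `_TRANSFORMER_BLOCK_IDENTIFIERS` / `_PROJ_OUT_IDENTIFIERS` templates are defined elsewhere in the file and are not shown in this diff):

```python
# Sketch of fullmatch-based module selection. Pattern strings and module names
# below are hypothetical; only the matching rule mirrors the code above.
import re
from typing import List

inactive_patterns: List[str] = [r"transformer_blocks\.\d+\.ff"]    # run fully only on refresh steps
active_patterns: List[str] = [r"transformer_blocks\.\d+\.attn"]    # Taylor-cached between refreshes

module_names = [
    "transformer_blocks.0.attn",
    "transformer_blocks.0.ff",
    "transformer_blocks.0.attn.to_q",  # no fullmatch -> left unhooked
]

for name in module_names:
    matches_inactive = any(re.fullmatch(p, name) for p in inactive_patterns)
    matches_active = any(re.fullmatch(p, name) for p in active_patterns)
    mode = "inactive" if matches_inactive else "active" if matches_active else "unhooked"
    print(name, "->", mode)
```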
@@ -254,60 +274,41 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfig):
        >>> pipe.to("cuda")

        >>> config = TaylorSeerCacheConfig(
        ...     predict_steps=5,
        ...     cache_interval=5,
        ...     max_order=1,
        ...     warmup_steps=3,
        ...     disable_cache_before_step=3,
        ...     taylor_factors_dtype=torch.float32,
        ... )
        >>> pipe.transformer.enable_cache(config)
        ```
    """
    skip_patterns, cache_patterns = _resolve_patterns(config)
    inactive_patterns, active_patterns = _resolve_patterns(config)

    logger.debug("TaylorSeer skip identifiers: %s", skip_patterns)
    logger.debug("TaylorSeer cache identifiers: %s", cache_patterns)
    active_patterns = active_patterns or _TRANSFORMER_BLOCK_IDENTIFIERS

    cache_patterns = cache_patterns or _TRANSFORMER_BLOCK_IDENTIFIERS

    if config.lite:
    if config.use_lite_mode:
        logger.info("Using TaylorSeer Lite variant for cache.")
        cache_patterns = _PROJ_OUT_IDENTIFIERS
        skip_patterns = _BLOCK_IDENTIFIERS
        if config.skip_identifiers or config.cache_identifiers:
        active_patterns = _PROJ_OUT_IDENTIFIERS
        inactive_patterns = _BLOCK_IDENTIFIERS
        if config.inactive_identifiers or config.active_identifiers:
            logger.warning("Lite mode overrides user patterns.")

    for name, submodule in module.named_modules():
        matches_skip = any(re.fullmatch(pattern, name) for pattern in skip_patterns)
        matches_cache = any(re.fullmatch(pattern, name) for pattern in cache_patterns)
        if not (matches_skip or matches_cache):
        matches_inactive = any(re.fullmatch(pattern, name) for pattern in inactive_patterns)
        matches_active = any(re.fullmatch(pattern, name) for pattern in active_patterns)
        if not (matches_inactive or matches_active):
            continue
        logger.debug(
            "Applying TaylorSeer cache to %s (mode=%s)",
            name,
            "skip" if matches_skip else "cache",
        )
        state_manager = StateManager(
            TaylorSeerState,
            init_kwargs={
                "taylor_factors_dtype": config.taylor_factors_dtype,
                "max_order": config.max_order,
            },
        )
        _apply_taylorseer_cache_hook(
            name=name,
            module=submodule,
            config=config,
            is_skip=matches_skip,
            state_manager=state_manager,
            is_inactive=matches_inactive,
        )


def _apply_taylorseer_cache_hook(
    name: str,
    module: nn.Module,
    config: TaylorSeerCacheConfig,
    is_skip: bool,
    state_manager: StateManager,
    is_inactive: bool,
):
    """
    Registers the TaylorSeer hook on the specified nn.Module.
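Each hooked module now builds its own `StateManager` inside `_apply_taylorseer_cache_hook`, passing the state class plus `init_kwargs`; the hook then calls `get_state()` on every forward and `reset()` between sampling runs. A rough, hypothetical stand-in for that interface (the real `StateManager` lives elsewhere in `diffusers.hooks` and may behave differently):

```python
# Hypothetical stand-in illustrating the StateManager interface used above
# (construct with a state class plus init_kwargs, hand out a state, reset it).
from typing import Any, Dict, Optional, Type


class ToyStateManager:
    def __init__(self, state_cls: Type, init_kwargs: Optional[Dict[str, Any]] = None):
        self._state_cls = state_cls
        self._init_kwargs = init_kwargs or {}
        self._state = None

    def get_state(self):
        # Lazily build one state object per manager (i.e. per hooked module).
        if self._state is None:
            self._state = self._state_cls(**self._init_kwargs)
        return self._state

    def reset(self):
        # Called between sampling runs so cached Taylor factors do not leak across prompts.
        if self._state is not None:
            self._state.reset()
```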
@@ -316,19 +317,25 @@ def _apply_taylorseer_cache_hook(
        name: Name of the module.
        module: The nn.Module to be hooked.
        config: Cache configuration.
        is_skip: Whether this module should operate in "skip" mode.
        state_manager: The state manager for managing hook state.
        is_inactive: Whether this module should operate in "inactive" mode.
    """
    state_manager = StateManager(
        TaylorSeerState,
        init_kwargs={
            "taylor_factors_dtype": config.taylor_factors_dtype,
            "max_order": config.max_order,
            "is_inactive": is_inactive,
        },
    )

    registry = HookRegistry.check_if_exists_or_initialize(module)

    hook = TaylorSeerCacheHook(
        module_name=name,
        predict_steps=config.predict_steps,
        warmup_steps=config.warmup_steps,
        cache_interval=config.cache_interval,
        disable_cache_before_step=config.disable_cache_before_step,
        taylor_factors_dtype=config.taylor_factors_dtype,
        stop_predicts=config.stop_predicts,
        is_skip=is_skip,
        disable_cache_after_step=config.disable_cache_after_step,
        state_manager=state_manager,
    )

    registry.register_hook(hook, _TAYLORSEER_CACHE_HOOK)
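For context, a hedged end-to-end sketch assembled from the docstring example above. The pipeline class, checkpoint name, prompt, and the exact set of config fields are assumptions (this commit is mid-rename for several of them), so treat it as an outline rather than the final API:

```python
# Hedged usage sketch based on the docstring example in this diff.
# Checkpoint, import path, and field names are illustrative; check the merged
# diffusers code for the final API.
import torch
from diffusers import FluxPipeline
from diffusers.hooks import TaylorSeerCacheConfig  # import path is an assumption

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")

config = TaylorSeerCacheConfig(
    cache_interval=5,               # refresh the Taylor factors every 5 denoising steps
    disable_cache_before_step=3,    # full computation for the first 3 steps
    max_order=1,
    taylor_factors_dtype=torch.float32,
)
pipe.transformer.enable_cache(config)

image = pipe("a photo of an astronaut riding a horse", num_inference_steps=30).images[0]
image.save("taylorseer.png")
```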