mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
quality & style
This commit is contained in:
@@ -30,34 +30,35 @@ class TaylorSeerCacheConfig:
|
||||
|
||||
Attributes:
|
||||
cache_interval (`int`, defaults to `5`):
|
||||
The interval between full computation steps. After a full computation, the cached (predicted) outputs are reused
|
||||
for this many subsequent denoising steps before refreshing with a new full forward pass.
|
||||
The interval between full computation steps. After a full computation, the cached (predicted) outputs are
|
||||
reused for this many subsequent denoising steps before refreshing with a new full forward pass.
|
||||
|
||||
disable_cache_before_step (`int`, defaults to `3`):
|
||||
The denoising step index before which caching is disabled, meaning full computation is performed for the initial
|
||||
steps (0 to disable_cache_before_step - 1) to gather data for Taylor series approximations. During these steps,
|
||||
Taylor factors are updated, but caching/predictions are not applied. Caching begins at this step.
|
||||
The denoising step index before which caching is disabled, meaning full computation is performed for the
|
||||
initial steps (0 to disable_cache_before_step - 1) to gather data for Taylor series approximations. During
|
||||
these steps, Taylor factors are updated, but caching/predictions are not applied. Caching begins at this
|
||||
step.
|
||||
|
||||
disable_cache_after_step (`int`, *optional*, defaults to `None`):
|
||||
The denoising step index after which caching is disabled. If set, for steps >= this value, all modules run full
|
||||
computations without predictions or state updates, ensuring accuracy in later stages if needed.
|
||||
The denoising step index after which caching is disabled. If set, for steps >= this value, all modules run
|
||||
full computations without predictions or state updates, ensuring accuracy in later stages if needed.
|
||||
|
||||
max_order (`int`, defaults to `1`):
|
||||
The highest order in the Taylor series expansion for approximating module outputs. Higher orders provide better
|
||||
approximations but increase computation and memory usage.
|
||||
The highest order in the Taylor series expansion for approximating module outputs. Higher orders provide
|
||||
better approximations but increase computation and memory usage.
|
||||
|
||||
taylor_factors_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
|
||||
Data type used for storing and computing Taylor series factors. Lower precision reduces memory but may affect
|
||||
stability; higher precision improves accuracy at the cost of more memory.
|
||||
Data type used for storing and computing Taylor series factors. Lower precision reduces memory but may
|
||||
affect stability; higher precision improves accuracy at the cost of more memory.
|
||||
|
||||
inactive_identifiers (`List[str]`, *optional*, defaults to `None`):
|
||||
Regex patterns (using `re.fullmatch`) for module names to place in "inactive" mode. In this mode, the module
|
||||
computes fully during initial or refresh steps but returns a zero tensor (matching recorded shape) during
|
||||
prediction steps to skip computation cheaply.
|
||||
Regex patterns (using `re.fullmatch`) for module names to place in "inactive" mode. In this mode, the
|
||||
module computes fully during initial or refresh steps but returns a zero tensor (matching recorded shape)
|
||||
during prediction steps to skip computation cheaply.
|
||||
|
||||
active_identifiers (`List[str]`, *optional*, defaults to `None`):
|
||||
Regex patterns (using `re.fullmatch`) for module names to place in Taylor-series caching mode, where outputs
|
||||
are approximated and cached for reuse.
|
||||
Regex patterns (using `re.fullmatch`) for module names to place in Taylor-series caching mode, where
|
||||
outputs are approximated and cached for reuse.
|
||||
|
||||
use_lite_mode (`bool`, *optional*, defaults to `False`):
|
||||
Enables a lightweight TaylorSeer variant that minimizes memory usage by applying predefined patterns for
|
||||
@@ -118,7 +119,6 @@ class TaylorSeerState:
|
||||
self.device: Optional[torch.device] = None
|
||||
self.current_step: int = -1
|
||||
|
||||
|
||||
def reset(self) -> None:
|
||||
self.current_step = -1
|
||||
self.last_update_step = None
|
||||
@@ -223,13 +223,9 @@ class TaylorSeerCacheHook(ModelHook):
|
||||
state.current_step += 1
|
||||
current_step = state.current_step
|
||||
is_warmup_phase = current_step < self.disable_cache_before_step
|
||||
is_compute_interval = ((current_step - self.disable_cache_before_step - 1) % self.cache_interval == 0)
|
||||
is_compute_interval = (current_step - self.disable_cache_before_step - 1) % self.cache_interval == 0
|
||||
is_cooldown_phase = self.disable_cache_after_step is not None and current_step >= self.disable_cache_after_step
|
||||
should_compute = (
|
||||
is_warmup_phase
|
||||
or is_compute_interval
|
||||
or is_cooldown_phase
|
||||
)
|
||||
should_compute = is_warmup_phase or is_compute_interval or is_cooldown_phase
|
||||
if should_compute:
|
||||
outputs = self.fn_ref.original_forward(*args, **kwargs)
|
||||
wrapped_outputs = (outputs,) if isinstance(outputs, torch.Tensor) else outputs
|
||||
@@ -255,8 +251,8 @@ def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfi
|
||||
"""
|
||||
Applies the TaylorSeer cache to a given pipeline (typically the transformer / UNet).
|
||||
|
||||
This function hooks selected modules in the model to enable caching or skipping based on the provided configuration,
|
||||
reducing redundant computations in diffusion denoising loops.
|
||||
This function hooks selected modules in the model to enable caching or skipping based on the provided
|
||||
configuration, reducing redundant computations in diffusion denoising loops.
|
||||
|
||||
Args:
|
||||
module (torch.nn.Module): The model subtree to apply the hooks to.
|
||||
@@ -338,4 +334,4 @@ def _apply_taylorseer_cache_hook(
|
||||
state_manager=state_manager,
|
||||
)
|
||||
|
||||
registry.register_hook(hook, _TAYLORSEER_CACHE_HOOK)
|
||||
registry.register_hook(hook, _TAYLORSEER_CACHE_HOOK)
|
||||
|
||||
Reference in New Issue
Block a user