1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

torch compile compatible

This commit is contained in:
toilaluan
2025-12-04 06:12:37 +00:00
parent 4fb3f53b6c
commit 76494ca098

View File

@@ -157,6 +157,7 @@ class TaylorSeerState:
self.last_update_step = self.current_step
@torch.compiler.disable
def predict(self) -> List[torch.Tensor]:
if self.last_update_step is None:
raise ValueError("Cannot predict without prior initialization/update.")
@@ -219,7 +220,8 @@ class TaylorSeerCacheHook(ModelHook):
"""
self.state_manager.reset()
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
@torch.compiler.disable
def _measure_should_compute(self) -> bool:
state: TaylorSeerState = self.state_manager.get_state()
state.current_step += 1
current_step = state.current_step
@@ -227,6 +229,10 @@ class TaylorSeerCacheHook(ModelHook):
is_compute_interval = (current_step - self.disable_cache_before_step - 1) % self.cache_interval == 0
is_cooldown_phase = self.disable_cache_after_step is not None and current_step >= self.disable_cache_after_step
should_compute = is_warmup_phase or is_compute_interval or is_cooldown_phase
return should_compute, state
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
should_compute, state = self._measure_should_compute()
if should_compute:
outputs = self.fn_ref.original_forward(*args, **kwargs)
wrapped_outputs = (outputs,) if isinstance(outputs, torch.Tensor) else outputs