diff --git a/src/diffusers/hooks/taylorseer_cache.py b/src/diffusers/hooks/taylorseer_cache.py index 607d652f4a..bbca345e8b 100644 --- a/src/diffusers/hooks/taylorseer_cache.py +++ b/src/diffusers/hooks/taylorseer_cache.py @@ -157,6 +157,7 @@ class TaylorSeerState: self.last_update_step = self.current_step + @torch.compiler.disable def predict(self) -> List[torch.Tensor]: if self.last_update_step is None: raise ValueError("Cannot predict without prior initialization/update.") @@ -219,7 +220,8 @@ class TaylorSeerCacheHook(ModelHook): """ self.state_manager.reset() - def new_forward(self, module: torch.nn.Module, *args, **kwargs): + @torch.compiler.disable + def _measure_should_compute(self) -> bool: state: TaylorSeerState = self.state_manager.get_state() state.current_step += 1 current_step = state.current_step @@ -227,6 +229,10 @@ class TaylorSeerCacheHook(ModelHook): is_compute_interval = (current_step - self.disable_cache_before_step - 1) % self.cache_interval == 0 is_cooldown_phase = self.disable_cache_after_step is not None and current_step >= self.disable_cache_after_step should_compute = is_warmup_phase or is_compute_interval or is_cooldown_phase + return should_compute, state + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + should_compute, state = self._measure_should_compute() if should_compute: outputs = self.fn_ref.original_forward(*args, **kwargs) wrapped_outputs = (outputs,) if isinstance(outputs, torch.Tensor) else outputs