mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
torch compile compatible
This commit is contained in:
@@ -157,6 +157,7 @@ class TaylorSeerState:
|
||||
|
||||
self.last_update_step = self.current_step
|
||||
|
||||
@torch.compiler.disable
|
||||
def predict(self) -> List[torch.Tensor]:
|
||||
if self.last_update_step is None:
|
||||
raise ValueError("Cannot predict without prior initialization/update.")
|
||||
@@ -219,7 +220,8 @@ class TaylorSeerCacheHook(ModelHook):
|
||||
"""
|
||||
self.state_manager.reset()
|
||||
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
@torch.compiler.disable
|
||||
def _measure_should_compute(self) -> bool:
|
||||
state: TaylorSeerState = self.state_manager.get_state()
|
||||
state.current_step += 1
|
||||
current_step = state.current_step
|
||||
@@ -227,6 +229,10 @@ class TaylorSeerCacheHook(ModelHook):
|
||||
is_compute_interval = (current_step - self.disable_cache_before_step - 1) % self.cache_interval == 0
|
||||
is_cooldown_phase = self.disable_cache_after_step is not None and current_step >= self.disable_cache_after_step
|
||||
should_compute = is_warmup_phase or is_compute_interval or is_cooldown_phase
|
||||
return should_compute, state
|
||||
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
should_compute, state = self._measure_should_compute()
|
||||
if should_compute:
|
||||
outputs = self.fn_ref.original_forward(*args, **kwargs)
|
||||
wrapped_outputs = (outputs,) if isinstance(outputs, torch.Tensor) else outputs
|
||||
|
||||
Reference in New Issue
Block a user