1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

fix taylor precision

This commit is contained in:
toilaluan
2025-11-28 08:14:41 +00:00
parent 309ce72140
commit d06c6bc6c2

View File

@@ -137,8 +137,7 @@ class TaylorSeerState:
self.inactive_shapes = tuple(output.shape for output in outputs)
else:
self.taylor_factors = {}
for i, output in enumerate(outputs):
features = output.to(self.taylor_factors_dtype)
for i, features in enumerate(outputs):
new_factors: Dict[int, torch.Tensor] = {0: features}
is_first_update = self.last_update_step is None
if not is_first_update:
@@ -152,8 +151,8 @@ class TaylorSeerState:
prev = prev_factors.get(j)
if prev is None:
break
new_factors[j + 1] = (new_factors[j] - prev.to(self.taylor_factors_dtype)) / delta_step
self.taylor_factors[i] = new_factors
new_factors[j + 1] = (new_factors[j] - prev.to(features.dtype)) / delta_step
self.taylor_factors[i] = {order: factor.to(self.taylor_factors_dtype) for order, factor in new_factors.items()}
self.last_update_step = self.current_step
@@ -179,14 +178,15 @@ class TaylorSeerState:
if not self.taylor_factors:
raise ValueError("Taylor factors empty during prediction.")
for i in range(len(self.module_dtypes)):
output_dtype = self.module_dtypes[i]
taylor_factors = self.taylor_factors[i]
# Accumulate Taylor series: f(t0 + Δt) ≈ Σ f^{(n)}(t0) * (Δt^n / n!)
output = torch.zeros_like(taylor_factors[0])
output = torch.zeros_like(taylor_factors[0], dtype=output_dtype)
for order, factor in taylor_factors.items():
# Note: order starts at 0
coeff = (step_offset**order) / math.factorial(order)
output = output + factor * coeff
outputs.append(output.to(self.module_dtypes[i]))
output = output + factor.to(output_dtype) * coeff
outputs.append(output)
return outputs