mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
fix taylor precision
This commit is contained in:
@@ -137,8 +137,7 @@ class TaylorSeerState:
|
||||
self.inactive_shapes = tuple(output.shape for output in outputs)
|
||||
else:
|
||||
self.taylor_factors = {}
|
||||
for i, output in enumerate(outputs):
|
||||
features = output.to(self.taylor_factors_dtype)
|
||||
for i, features in enumerate(outputs):
|
||||
new_factors: Dict[int, torch.Tensor] = {0: features}
|
||||
is_first_update = self.last_update_step is None
|
||||
if not is_first_update:
|
||||
@@ -152,8 +151,8 @@ class TaylorSeerState:
|
||||
prev = prev_factors.get(j)
|
||||
if prev is None:
|
||||
break
|
||||
new_factors[j + 1] = (new_factors[j] - prev.to(self.taylor_factors_dtype)) / delta_step
|
||||
self.taylor_factors[i] = new_factors
|
||||
new_factors[j + 1] = (new_factors[j] - prev.to(features.dtype)) / delta_step
|
||||
self.taylor_factors[i] = {order: factor.to(self.taylor_factors_dtype) for order, factor in new_factors.items()}
|
||||
|
||||
self.last_update_step = self.current_step
|
||||
|
||||
@@ -179,14 +178,15 @@ class TaylorSeerState:
|
||||
if not self.taylor_factors:
|
||||
raise ValueError("Taylor factors empty during prediction.")
|
||||
for i in range(len(self.module_dtypes)):
|
||||
output_dtype = self.module_dtypes[i]
|
||||
taylor_factors = self.taylor_factors[i]
|
||||
# Accumulate Taylor series: f(t0 + Δt) ≈ Σ f^{(n)}(t0) * (Δt^n / n!)
|
||||
output = torch.zeros_like(taylor_factors[0])
|
||||
output = torch.zeros_like(taylor_factors[0], dtype=output_dtype)
|
||||
for order, factor in taylor_factors.items():
|
||||
# Note: order starts at 0
|
||||
coeff = (step_offset**order) / math.factorial(order)
|
||||
output = output + factor * coeff
|
||||
outputs.append(output.to(self.module_dtypes[i]))
|
||||
output = output + factor.to(output_dtype) * coeff
|
||||
outputs.append(output)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
Reference in New Issue
Block a user