diff --git a/src/diffusers/hooks/taylorseer_cache.py b/src/diffusers/hooks/taylorseer_cache.py index e1e0bebcf0..3554255744 100644 --- a/src/diffusers/hooks/taylorseer_cache.py +++ b/src/diffusers/hooks/taylorseer_cache.py @@ -137,8 +137,7 @@ class TaylorSeerState: self.inactive_shapes = tuple(output.shape for output in outputs) else: self.taylor_factors = {} - for i, output in enumerate(outputs): - features = output.to(self.taylor_factors_dtype) + for i, features in enumerate(outputs): new_factors: Dict[int, torch.Tensor] = {0: features} is_first_update = self.last_update_step is None if not is_first_update: @@ -152,8 +151,8 @@ class TaylorSeerState: prev = prev_factors.get(j) if prev is None: break - new_factors[j + 1] = (new_factors[j] - prev.to(self.taylor_factors_dtype)) / delta_step - self.taylor_factors[i] = new_factors + new_factors[j + 1] = (new_factors[j] - prev.to(features.dtype)) / delta_step + self.taylor_factors[i] = {order: factor.to(self.taylor_factors_dtype) for order, factor in new_factors.items()} self.last_update_step = self.current_step @@ -179,14 +178,15 @@ class TaylorSeerState: if not self.taylor_factors: raise ValueError("Taylor factors empty during prediction.") for i in range(len(self.module_dtypes)): + output_dtype = self.module_dtypes[i] taylor_factors = self.taylor_factors[i] # Accumulate Taylor series: f(t0 + Δt) ≈ Σ f^{(n)}(t0) * (Δt^n / n!) - output = torch.zeros_like(taylor_factors[0]) + output = torch.zeros_like(taylor_factors[0], dtype=output_dtype) for order, factor in taylor_factors.items(): # Note: order starts at 0 coeff = (step_offset**order) / math.factorial(order) - output = output + factor * coeff - outputs.append(output.to(self.module_dtypes[i])) + output = output + factor.to(output_dtype) * coeff + outputs.append(output) return outputs