From aa2ce41b99a7990a8eb03bf2bf9253a40909b31e Mon Sep 17 00:00:00 2001 From: NotNANtoN Date: Fri, 18 Nov 2022 16:01:57 +0100 Subject: [PATCH] Fix img2img speed with LMS-Discrete Scheduler (#896) Casting `self.sigmas` into a different dtype (the one of original_samples) is not advisable. In my img2img pipeline this leads to a long running time in the `integrate.quad` call later on- by long I mean more than 10x slower. Co-authored-by: Anton Lozhkov --- src/diffusers/schedulers/scheduling_lms_discrete.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 8a9aedb41b..cc9e8d7256 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -243,19 +243,18 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): timesteps: torch.FloatTensor, ) -> torch.FloatTensor: # Make sure sigmas and timesteps have the same device and dtype as original_samples - self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): # mps does not support float64 - self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) timesteps = timesteps.to(original_samples.device, dtype=torch.float32) else: - self.timesteps = self.timesteps.to(original_samples.device) + schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - schedule_timesteps = self.timesteps step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - sigma = self.sigmas[step_indices].flatten() + sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): sigma = sigma.unsqueeze(-1)