mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Fix img2img speed with LMS-Discrete Scheduler (#896)
Casting `self.sigmas` into a different dtype (the one of original_samples) is not advisable. In my img2img pipeline this leads to a long running time in the `integrate.quad` call later on- by long I mean more than 10x slower. Co-authored-by: Anton Lozhkov <anton@huggingface.co>
This commit is contained in:
@@ -243,19 +243,18 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
timesteps: torch.FloatTensor,
|
||||
) -> torch.FloatTensor:
|
||||
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||
self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
||||
# mps does not support float64
|
||||
self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
|
||||
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
|
||||
else:
|
||||
self.timesteps = self.timesteps.to(original_samples.device)
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
schedule_timesteps = self.timesteps
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
|
||||
sigma = self.sigmas[step_indices].flatten()
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
sigma = sigma.unsqueeze(-1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user