Mirror of https://github.com/huggingface/diffusers.git
[bug fix] dpm multistep solver duplicate timesteps
commit b8bfa562dc
parent 3de609dd66
committed by Daniel Gu
@@ -192,14 +192,22 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             device (`str` or `torch.device`, optional):
                 the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         """
-        self.num_inference_steps = num_inference_steps
         timesteps = (
             np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
             .round()[::-1][:-1]
             .copy()
             .astype(np.int64)
         )
+
+        # when num_inference_steps == num_train_timesteps, we can end up with
+        # duplicates in timesteps.
+        _, unique_indices = np.unique(timesteps, return_index=True)
+        timesteps = timesteps[np.sort(unique_indices)]
+
         self.timesteps = torch.from_numpy(timesteps).to(device)
+
+        self.num_inference_steps = len(timesteps)
+
         self.model_outputs = [
             None,
         ] * self.config.solver_order
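
For context on the failure mode the new comment describes, below is a minimal standalone sketch, not part of the commit, that reproduces the duplicate timesteps and applies the same dedup step. The variable names and the value 1000 (the scheduler's usual num_train_timesteps default) are assumptions for illustration.

    import numpy as np

    # assumed values: setting num_inference_steps equal to num_train_timesteps
    # is the case that exposes the bug
    num_train_timesteps = 1000
    num_inference_steps = 1000

    # same construction as set_timesteps: evenly spaced points rounded to
    # integers, reversed into descending order, with the final (k = 0) entry dropped
    timesteps = (
        np.linspace(0, num_train_timesteps - 1, num_inference_steps + 1)
        .round()[::-1][:-1]
        .copy()
        .astype(np.int64)
    )
    print(len(timesteps), len(np.unique(timesteps)))  # expected: fewer unique values than entries

    # the fix: np.unique also returns the index of the first occurrence of each value;
    # sorting those indices and re-indexing drops the repeats while keeping the
    # original descending order
    _, unique_indices = np.unique(timesteps, return_index=True)
    timesteps = timesteps[np.sort(unique_indices)]
    assert len(timesteps) == len(np.unique(timesteps))
    print(timesteps[:5], timesteps[-5:])  # still in descending order

The np.sort(unique_indices) step matters because np.unique alone returns values in ascending order, while the scheduler expects timesteps in descending order; re-indexing with the sorted first-occurrence indices preserves that ordering. Assigning self.num_inference_steps = len(timesteps) afterwards keeps the step count consistent with the deduplicated schedule.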