1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Update LCMScheduler Inference Timesteps to be More Evenly Spaced (#5836)

* Change LCMScheduler.set_timesteps to pick more evenly spaced inference timesteps.

* Change inference_indices implementation to better match previous behavior.

* Add num_inference_steps=26 test case to test_inference_steps.

* run CI

---------

Co-authored-by: patil-suraj <surajp815@gmail.com>
This commit is contained in:
dg845
2023-11-20 06:46:10 -08:00
committed by GitHub
parent 3303aec5f8
commit dc21498b43
2 changed files with 12 additions and 7 deletions

View File

@@ -371,10 +371,11 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
)
# LCM Timesteps Setting
# Currently, only linear spacing is supported.
c = self.config.num_train_timesteps // original_steps
# LCM Training Steps Schedule
lcm_origin_timesteps = np.asarray(list(range(1, int(original_steps * strength) + 1))) * c - 1
# The skipping step parameter k from the paper.
k = self.config.num_train_timesteps // original_steps
# LCM Training/Distillation Steps Schedule
# Currently, only a linearly-spaced schedule is supported (same as in the LCM distillation scripts).
lcm_origin_timesteps = np.asarray(list(range(1, int(original_steps * strength) + 1))) * k - 1
skipping_step = len(lcm_origin_timesteps) // num_inference_steps
if skipping_step < 1:
@@ -383,9 +384,13 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
)
# LCM Inference Steps Schedule
timesteps = lcm_origin_timesteps[::-skipping_step][:num_inference_steps]
lcm_origin_timesteps = lcm_origin_timesteps[::-1].copy()
# Select (approximately) evenly spaced indices from lcm_origin_timesteps.
inference_indices = np.linspace(0, len(lcm_origin_timesteps), num=num_inference_steps, endpoint=False)
inference_indices = np.floor(inference_indices).astype(np.int64)
timesteps = lcm_origin_timesteps[inference_indices]
self.timesteps = torch.from_numpy(timesteps.copy()).to(device=device, dtype=torch.long)
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.long)
self._step_index = None

View File

@@ -84,7 +84,7 @@ class LCMSchedulerTest(SchedulerCommonTest):
def test_inference_steps(self):
# Hardcoded for now
for t, num_inference_steps in zip([99, 39, 19], [10, 25, 50]):
for t, num_inference_steps in zip([99, 39, 39, 19], [10, 25, 26, 50]):
self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)
# Override test_add_noise_device because the hardcoded num_inference_steps of 100 doesn't work