mirror of https://github.com/huggingface/diffusers.git
Update LCMScheduler Inference Timesteps to be More Evenly Spaced (#5836)
* Change LCMScheduler.set_timesteps to pick more evenly spaced inference timesteps.
* Change inference_indices implementation to better match previous behavior.
* Add num_inference_steps=26 test case to test_inference_steps.
* run CI

Co-authored-by: patil-suraj <surajp815@gmail.com>
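A minimal sketch of the effect, assuming LCMScheduler's defaults of num_train_timesteps=1000 and original_inference_steps=50 with strength=1.0 (values not stated in this diff): for a step count that does not divide the 50-entry distillation schedule evenly, the old strided selection bunches timesteps at the high-noise end, while the new floor-of-linspace selection spreads them across the whole schedule.

import numpy as np

# Assumed LCMScheduler defaults (not shown in this diff):
num_train_timesteps = 1000
original_steps = 50
strength = 1.0
num_inference_steps = 26

# The skipping step parameter k from the paper: 1000 // 50 == 20.
k = num_train_timesteps // original_steps
# Ascending distillation schedule: [19, 39, ..., 999] (50 entries).
lcm_origin_timesteps = np.asarray(list(range(1, int(original_steps * strength) + 1))) * k - 1

# Old selection: stride backwards by skipping_step, then truncate.
skipping_step = len(lcm_origin_timesteps) // num_inference_steps  # 50 // 26 == 1
old = lcm_origin_timesteps[::-skipping_step][:num_inference_steps]

# New selection: floor of evenly spaced positions over the reversed schedule.
reversed_schedule = lcm_origin_timesteps[::-1].copy()
inference_indices = np.floor(
    np.linspace(0, len(reversed_schedule), num=num_inference_steps, endpoint=False)
).astype(np.int64)
new = reversed_schedule[inference_indices]

print(old[0], old[-1])  # 999 499 -- bunched into the top half of the schedule
print(new[0], new[-1])  # 999 39  -- spread across the whole schedule

Because skipping_step is floored to 1 when 50 // 26 == 1, the old rule degenerated into taking the 26 largest timesteps; the linspace-based rule keeps the low-noise end of the schedule covered.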
--- a/src/diffusers/schedulers/scheduling_lcm.py
+++ b/src/diffusers/schedulers/scheduling_lcm.py
@@ -371,10 +371,11 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
             )
 
         # LCM Timesteps Setting
-        # Currently, only linear spacing is supported.
-        c = self.config.num_train_timesteps // original_steps
-        # LCM Training Steps Schedule
-        lcm_origin_timesteps = np.asarray(list(range(1, int(original_steps * strength) + 1))) * c - 1
+        # The skipping step parameter k from the paper.
+        k = self.config.num_train_timesteps // original_steps
+        # LCM Training/Distillation Steps Schedule
+        # Currently, only a linearly-spaced schedule is supported (same as in the LCM distillation scripts).
+        lcm_origin_timesteps = np.asarray(list(range(1, int(original_steps * strength) + 1))) * k - 1
         skipping_step = len(lcm_origin_timesteps) // num_inference_steps
 
         if skipping_step < 1:
@@ -383,9 +384,13 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
             )
 
         # LCM Inference Steps Schedule
-        timesteps = lcm_origin_timesteps[::-skipping_step][:num_inference_steps]
+        lcm_origin_timesteps = lcm_origin_timesteps[::-1].copy()
+        # Select (approximately) evenly spaced indices from lcm_origin_timesteps.
+        inference_indices = np.linspace(0, len(lcm_origin_timesteps), num=num_inference_steps, endpoint=False)
+        inference_indices = np.floor(inference_indices).astype(np.int64)
+        timesteps = lcm_origin_timesteps[inference_indices]
 
-        self.timesteps = torch.from_numpy(timesteps.copy()).to(device=device, dtype=torch.long)
+        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.long)
 
         self._step_index = None
 
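A point worth noting about the inference_indices construction above: whenever num_inference_steps divides len(lcm_origin_timesteps) exactly, np.floor over the evenly spaced positions lands on exactly the indices the old [::-skipping_step] stride visited, so pre-existing schedules are unchanged. A quick check under the same assumed defaults as the sketch above:

import numpy as np

# Assumed defaults: num_train_timesteps=1000, original_inference_steps=50.
schedule = np.asarray(list(range(1, 51))) * (1000 // 50) - 1  # [19, ..., 999]

for num_inference_steps in (10, 25, 50):  # all divide 50 exactly
    skipping_step = len(schedule) // num_inference_steps
    old = schedule[::-skipping_step][:num_inference_steps]

    desc = schedule[::-1].copy()
    inference_indices = np.floor(
        np.linspace(0, len(desc), num=num_inference_steps, endpoint=False)
    ).astype(np.int64)
    new = desc[inference_indices]

    assert (old == new).all()  # identical schedules for exact divisors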
--- a/tests/schedulers/test_scheduler_lcm.py
+++ b/tests/schedulers/test_scheduler_lcm.py
@@ -84,7 +84,7 @@ class LCMSchedulerTest(SchedulerCommonTest):
 
     def test_inference_steps(self):
         # Hardcoded for now
-        for t, num_inference_steps in zip([99, 39, 19], [10, 25, 50]):
+        for t, num_inference_steps in zip([99, 39, 39, 19], [10, 25, 26, 50]):
             self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)
 
     # Override test_add_noise_device because the hardcoded num_inference_steps of 100 doesn't work
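The added num_inference_steps=26 case exercises the non-exact-divisor path (50 // 26 == 1) that motivated the change. The hardcoded t values appear to correspond to the smallest timestep of each resulting schedule, which can be reproduced under the same assumed defaults:

import numpy as np

# Assumed defaults: num_train_timesteps=1000, original_inference_steps=50.
schedule = np.asarray(list(range(1, 51))) * (1000 // 50) - 1
desc = schedule[::-1].copy()

for num_inference_steps in (10, 25, 26, 50):
    inference_indices = np.floor(
        np.linspace(0, len(desc), num=num_inference_steps, endpoint=False)
    ).astype(np.int64)
    print(num_inference_steps, desc[inference_indices][-1])
# 10 -> 99, 25 -> 39, 26 -> 39, 50 -> 19, matching [99, 39, 39, 19] above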