1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

fix a bug in 2nd order schedulers when using in ensemble of experts config (#5511)

* fix

* fix copies

* remove heun from tests

* add back heun and fix the tests to include 2nd order

* fix the other test too

* Apply suggestions from code review

* Apply suggestions from code review

* Apply suggestions from code review

* make style

* add more comments

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
YiYi Xu
2023-10-25 00:34:05 -10:00
committed by GitHub
parent de71fa59f5
commit 0fc25715a1
5 changed files with 78 additions and 24 deletions

View File

@@ -896,8 +896,20 @@ class StableDiffusionXLControlNetInpaintPipeline(
- (denoising_start * self.scheduler.config.num_train_timesteps)
)
)
timesteps = list(filter(lambda ts: ts < discrete_timestep_cutoff, timesteps))
return torch.tensor(timesteps), len(timesteps)
num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
if self.scheduler.order == 2:
# if the scheduler is a 2nd order scheduler we ALWAYS have to do +1
# because `num_inference_steps` will always be even given that every timestep
# (except the highest one) is duplicated. If `num_inference_steps` is even it would
# mean that we cut the timesteps in the middle of the denoising step
# (between 1st and 2nd devirative) which leads to incorrect results. By adding 1
# we ensure that the denoising process always ends after the 2nd derivate step of the scheduler
num_inference_steps = num_inference_steps + 1
# because t_n+1 >= t_n, we slice the timesteps starting from the end
timesteps = timesteps[-num_inference_steps:]
return timesteps, num_inference_steps
return timesteps, num_inference_steps - t_start

View File

@@ -553,8 +553,20 @@ class StableDiffusionXLImg2ImgPipeline(
- (denoising_start * self.scheduler.config.num_train_timesteps)
)
)
timesteps = list(filter(lambda ts: ts < discrete_timestep_cutoff, timesteps))
return torch.tensor(timesteps), len(timesteps)
num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
if self.scheduler.order == 2:
# if the scheduler is a 2nd order scheduler we ALWAYS have to do +1
# because `num_inference_steps` will always be even given that every timestep
# (except the highest one) is duplicated. If `num_inference_steps` is even it would
# mean that we cut the timesteps in the middle of the denoising step
# (between 1st and 2nd devirative) which leads to incorrect results. By adding 1
# we ensure that the denoising process always ends after the 2nd derivate step of the scheduler
num_inference_steps = num_inference_steps + 1
# because t_n+1 >= t_n, we slice the timesteps starting from the end
timesteps = timesteps[-num_inference_steps:]
return timesteps, num_inference_steps
return timesteps, num_inference_steps - t_start

View File

@@ -838,8 +838,20 @@ class StableDiffusionXLInpaintPipeline(
- (denoising_start * self.scheduler.config.num_train_timesteps)
)
)
timesteps = list(filter(lambda ts: ts < discrete_timestep_cutoff, timesteps))
return torch.tensor(timesteps), len(timesteps)
num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
if self.scheduler.order == 2:
# if the scheduler is a 2nd order scheduler we ALWAYS have to do +1
# because `num_inference_steps` will always be even given that every timestep
# (except the highest one) is duplicated. If `num_inference_steps` is even it would
# mean that we cut the timesteps in the middle of the denoising step
# (between 1st and 2nd devirative) which leads to incorrect results. By adding 1
# we ensure that the denoising process always ends after the 2nd derivate step of the scheduler
num_inference_steps = num_inference_steps + 1
# because t_n+1 >= t_n, we slice the timesteps starting from the end
timesteps = timesteps[-num_inference_steps:]
return timesteps, num_inference_steps
return timesteps, num_inference_steps - t_start

View File

@@ -328,8 +328,13 @@ class StableDiffusionXLPipelineFastTests(
pipe_1.scheduler.set_timesteps(num_steps)
expected_steps = pipe_1.scheduler.timesteps.tolist()
expected_steps_1 = list(filter(lambda ts: ts >= split, expected_tss))
expected_steps_2 = list(filter(lambda ts: ts < split, expected_tss))
if pipe_1.scheduler.order == 2:
expected_steps_1 = list(filter(lambda ts: ts >= split, expected_tss))
expected_steps_2 = expected_steps_1[-1:] + list(filter(lambda ts: ts < split, expected_tss))
expected_steps = expected_steps_1 + expected_steps_2
else:
expected_steps_1 = list(filter(lambda ts: ts >= split, expected_tss))
expected_steps_2 = list(filter(lambda ts: ts < split, expected_tss))
# now we monkey patch step `done_steps`
# list into the step function for testing
@@ -611,13 +616,18 @@ class StableDiffusionXLPipelineFastTests(
split_1_ts = num_train_timesteps - int(round(num_train_timesteps * split_1))
split_2_ts = num_train_timesteps - int(round(num_train_timesteps * split_2))
expected_steps_1 = expected_steps[:split_1_ts]
expected_steps_2 = expected_steps[split_1_ts:split_2_ts]
expected_steps_3 = expected_steps[split_2_ts:]
expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps))
expected_steps_2 = list(filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps))
expected_steps_3 = list(filter(lambda ts: ts < split_2_ts, expected_steps))
if pipe_1.scheduler.order == 2:
expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps))
expected_steps_2 = expected_steps_1[-1:] + list(
filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps)
)
expected_steps_3 = expected_steps_2[-1:] + list(filter(lambda ts: ts < split_2_ts, expected_steps))
expected_steps = expected_steps_1 + expected_steps_2 + expected_steps_3
else:
expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps))
expected_steps_2 = list(filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps))
expected_steps_3 = list(filter(lambda ts: ts < split_2_ts, expected_steps))
# now we monkey patch step `done_steps`
# list into the step function for testing

View File

@@ -318,11 +318,14 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
expected_steps = pipe_1.scheduler.timesteps.tolist()
split_ts = num_train_timesteps - int(round(num_train_timesteps * split))
expected_steps_1 = expected_steps[:split_ts]
expected_steps_2 = expected_steps[split_ts:]
expected_steps_1 = list(filter(lambda ts: ts >= split_ts, expected_steps))
expected_steps_2 = list(filter(lambda ts: ts < split_ts, expected_steps))
if pipe_1.scheduler.order == 2:
expected_steps_1 = list(filter(lambda ts: ts >= split_ts, expected_steps))
expected_steps_2 = expected_steps_1[-1:] + list(filter(lambda ts: ts < split_ts, expected_steps))
expected_steps = expected_steps_1 + expected_steps_2
else:
expected_steps_1 = list(filter(lambda ts: ts >= split_ts, expected_steps))
expected_steps_2 = list(filter(lambda ts: ts < split_ts, expected_steps))
# now we monkey patch step `done_steps`
# list into the step function for testing
@@ -389,13 +392,18 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
split_1_ts = num_train_timesteps - int(round(num_train_timesteps * split_1))
split_2_ts = num_train_timesteps - int(round(num_train_timesteps * split_2))
expected_steps_1 = expected_steps[:split_1_ts]
expected_steps_2 = expected_steps[split_1_ts:split_2_ts]
expected_steps_3 = expected_steps[split_2_ts:]
expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps))
expected_steps_2 = list(filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps))
expected_steps_3 = list(filter(lambda ts: ts < split_2_ts, expected_steps))
if pipe_1.scheduler.order == 2:
expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps))
expected_steps_2 = expected_steps_1[-1:] + list(
filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps)
)
expected_steps_3 = expected_steps_2[-1:] + list(filter(lambda ts: ts < split_2_ts, expected_steps))
expected_steps = expected_steps_1 + expected_steps_2 + expected_steps_3
else:
expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps))
expected_steps_2 = list(filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps))
expected_steps_3 = list(filter(lambda ts: ts < split_2_ts, expected_steps))
# now we monkey patch step `done_steps`
# list into the step function for testing