From 0fc25715a11ef6688eff86adcc474287a9f50c1e Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Wed, 25 Oct 2023 00:34:05 -1000 Subject: [PATCH] fix a bug in 2nd order schedulers when using in ensemble of experts config (#5511) * fix * fix copies * remove heun from tests * add back heun and fix the tests to include 2nd order * fix the other test too * Apply suggestions from code review * Apply suggestions from code review * Apply suggestions from code review * make style * add more comments --------- Co-authored-by: yiyixuxu Co-authored-by: Patrick von Platen --- .../pipeline_controlnet_inpaint_sd_xl.py | 16 +++++++++-- .../pipeline_stable_diffusion_xl_img2img.py | 16 +++++++++-- .../pipeline_stable_diffusion_xl_inpaint.py | 16 +++++++++-- .../test_stable_diffusion_xl.py | 26 +++++++++++------ .../test_stable_diffusion_xl_inpaint.py | 28 ++++++++++++------- 5 files changed, 78 insertions(+), 24 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py index 6c5d9a3993..1423920191 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py @@ -896,8 +896,20 @@ class StableDiffusionXLControlNetInpaintPipeline( - (denoising_start * self.scheduler.config.num_train_timesteps) ) ) - timesteps = list(filter(lambda ts: ts < discrete_timestep_cutoff, timesteps)) - return torch.tensor(timesteps), len(timesteps) + + num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item() + if self.scheduler.order == 2: + # if the scheduler is a 2nd order scheduler we ALWAYS have to do +1 + # because `num_inference_steps` will always be even given that every timestep + # (except the highest one) is duplicated. 
If `num_inference_steps` is even it would + # mean that we cut the timesteps in the middle of the denoising step + # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1 + # we ensure that the denoising process always ends after the 2nd derivative step of the scheduler + num_inference_steps = num_inference_steps + 1 + + # because t_n+1 >= t_n, we slice the timesteps starting from the end + timesteps = timesteps[-num_inference_steps:] + return timesteps, num_inference_steps return timesteps, num_inference_steps - t_start diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 825c74ce07..ff9d669a80 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -553,8 +553,20 @@ class StableDiffusionXLImg2ImgPipeline( - (denoising_start * self.scheduler.config.num_train_timesteps) ) ) - timesteps = list(filter(lambda ts: ts < discrete_timestep_cutoff, timesteps)) - return torch.tensor(timesteps), len(timesteps) + + num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item() + if self.scheduler.order == 2: + # if the scheduler is a 2nd order scheduler we ALWAYS have to do +1 + # because `num_inference_steps` will always be even given that every timestep + # (except the highest one) is duplicated. If `num_inference_steps` is even it would + # mean that we cut the timesteps in the middle of the denoising step + # (between 1st and 2nd derivative) which leads to incorrect results. 
By adding 1 + # we ensure that the denoising process always ends after the 2nd derivative step of the scheduler + num_inference_steps = num_inference_steps + 1 + + # because t_n+1 >= t_n, we slice the timesteps starting from the end + timesteps = timesteps[-num_inference_steps:] + return timesteps, num_inference_steps return timesteps, num_inference_steps - t_start diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 535cc72683..200f5a7bf4 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -838,8 +838,20 @@ class StableDiffusionXLInpaintPipeline( - (denoising_start * self.scheduler.config.num_train_timesteps) ) ) - timesteps = list(filter(lambda ts: ts < discrete_timestep_cutoff, timesteps)) - return torch.tensor(timesteps), len(timesteps) + + num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item() + if self.scheduler.order == 2: + # if the scheduler is a 2nd order scheduler we ALWAYS have to do +1 + # because `num_inference_steps` will always be even given that every timestep + # (except the highest one) is duplicated. If `num_inference_steps` is even it would + # mean that we cut the timesteps in the middle of the denoising step + # (between 1st and 2nd derivative) which leads to incorrect results. 
By adding 1 + # we ensure that the denoising process always ends after the 2nd derivative step of the scheduler + num_inference_steps = num_inference_steps + 1 + + # because t_n+1 >= t_n, we slice the timesteps starting from the end + timesteps = timesteps[-num_inference_steps:] + return timesteps, num_inference_steps return timesteps, num_inference_steps - t_start diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py index 4906670890..f628ad741d 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py @@ -328,8 +328,13 @@ class StableDiffusionXLPipelineFastTests( pipe_1.scheduler.set_timesteps(num_steps) expected_steps = pipe_1.scheduler.timesteps.tolist() - expected_steps_1 = list(filter(lambda ts: ts >= split, expected_tss)) - expected_steps_2 = list(filter(lambda ts: ts < split, expected_tss)) + if pipe_1.scheduler.order == 2: + expected_steps_1 = list(filter(lambda ts: ts >= split, expected_tss)) + expected_steps_2 = expected_steps_1[-1:] + list(filter(lambda ts: ts < split, expected_tss)) + expected_steps = expected_steps_1 + expected_steps_2 + else: + expected_steps_1 = list(filter(lambda ts: ts >= split, expected_tss)) + expected_steps_2 = list(filter(lambda ts: ts < split, expected_tss)) # now we monkey patch step `done_steps` # list into the step function for testing @@ -611,13 +616,18 @@ class StableDiffusionXLPipelineFastTests( split_1_ts = num_train_timesteps - int(round(num_train_timesteps * split_1)) split_2_ts = num_train_timesteps - int(round(num_train_timesteps * split_2)) - expected_steps_1 = expected_steps[:split_1_ts] - expected_steps_2 = expected_steps[split_1_ts:split_2_ts] - expected_steps_3 = expected_steps[split_2_ts:] - expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps)) - expected_steps_2 = list(filter(lambda ts: ts >= split_2_ts and ts < 
split_1_ts, expected_steps)) - expected_steps_3 = list(filter(lambda ts: ts < split_2_ts, expected_steps)) + if pipe_1.scheduler.order == 2: + expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps)) + expected_steps_2 = expected_steps_1[-1:] + list( + filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps) + ) + expected_steps_3 = expected_steps_2[-1:] + list(filter(lambda ts: ts < split_2_ts, expected_steps)) + expected_steps = expected_steps_1 + expected_steps_2 + expected_steps_3 + else: + expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps)) + expected_steps_2 = list(filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps)) + expected_steps_3 = list(filter(lambda ts: ts < split_2_ts, expected_steps)) # now we monkey patch step `done_steps` # list into the step function for testing diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py index 7e3698d8ca..898fda0d7b 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py @@ -318,11 +318,14 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel expected_steps = pipe_1.scheduler.timesteps.tolist() split_ts = num_train_timesteps - int(round(num_train_timesteps * split)) - expected_steps_1 = expected_steps[:split_ts] - expected_steps_2 = expected_steps[split_ts:] - expected_steps_1 = list(filter(lambda ts: ts >= split_ts, expected_steps)) - expected_steps_2 = list(filter(lambda ts: ts < split_ts, expected_steps)) + if pipe_1.scheduler.order == 2: + expected_steps_1 = list(filter(lambda ts: ts >= split_ts, expected_steps)) + expected_steps_2 = expected_steps_1[-1:] + list(filter(lambda ts: ts < split_ts, expected_steps)) + expected_steps = expected_steps_1 + expected_steps_2 + else: + expected_steps_1 = 
list(filter(lambda ts: ts >= split_ts, expected_steps)) + expected_steps_2 = list(filter(lambda ts: ts < split_ts, expected_steps)) # now we monkey patch step `done_steps` # list into the step function for testing @@ -389,13 +392,18 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel split_1_ts = num_train_timesteps - int(round(num_train_timesteps * split_1)) split_2_ts = num_train_timesteps - int(round(num_train_timesteps * split_2)) - expected_steps_1 = expected_steps[:split_1_ts] - expected_steps_2 = expected_steps[split_1_ts:split_2_ts] - expected_steps_3 = expected_steps[split_2_ts:] - expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps)) - expected_steps_2 = list(filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps)) - expected_steps_3 = list(filter(lambda ts: ts < split_2_ts, expected_steps)) + if pipe_1.scheduler.order == 2: + expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps)) + expected_steps_2 = expected_steps_1[-1:] + list( + filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps) + ) + expected_steps_3 = expected_steps_2[-1:] + list(filter(lambda ts: ts < split_2_ts, expected_steps)) + expected_steps = expected_steps_1 + expected_steps_2 + expected_steps_3 + else: + expected_steps_1 = list(filter(lambda ts: ts >= split_1_ts, expected_steps)) + expected_steps_2 = list(filter(lambda ts: ts >= split_2_ts and ts < split_1_ts, expected_steps)) + expected_steps_3 = list(filter(lambda ts: ts < split_2_ts, expected_steps)) # now we monkey patch step `done_steps` # list into the step function for testing