From 63b684684979b22d9cfbeacc534c2f274e95e3ab Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Tue, 19 Mar 2024 00:50:58 -1000 Subject: [PATCH] [scheduler] fix a bug in add_noise (#7386) * fix * fix * add a tests * fix --------- Co-authored-by: Sayak Paul Co-authored-by: yiyixuxu --- ...ipeline_stable_diffusion_inpaint_legacy.py | 3 --- .../pipeline_stable_diffusion_diffedit.py | 3 --- .../scheduling_consistency_models.py | 4 ++++ .../schedulers/scheduling_deis_multistep.py | 6 ++++- .../scheduling_dpmsolver_multistep.py | 6 ++++- .../schedulers/scheduling_dpmsolver_sde.py | 4 ++++ .../scheduling_dpmsolver_singlestep.py | 6 ++++- .../scheduling_edm_dpmsolver_multistep.py | 4 ++++ .../schedulers/scheduling_edm_euler.py | 4 ++++ .../scheduling_euler_ancestral_discrete.py | 4 ++++ .../schedulers/scheduling_euler_discrete.py | 4 ++++ .../schedulers/scheduling_heun_discrete.py | 4 ++++ .../scheduling_k_dpm_2_ancestral_discrete.py | 4 ++++ .../schedulers/scheduling_k_dpm_2_discrete.py | 4 ++++ .../schedulers/scheduling_lms_discrete.py | 4 ++++ .../schedulers/scheduling_unipc_multistep.py | 6 ++++- .../test_stable_diffusion_inpaint.py | 24 +++++++++++++++++++ 17 files changed, 84 insertions(+), 10 deletions(-) diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py index 980adf2737..c7dff9eeef 100644 --- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py @@ -528,15 +528,12 @@ class StableDiffusionInpaintPipelineLegacy( f" {negative_prompt_embeds.shape}." ) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps def get_timesteps(self, num_inference_steps, strength, device): # get the original timestep using init_timestep init_timestep = min(int(num_inference_steps * strength), num_inference_steps) t_start = max(num_inference_steps - init_timestep, 0) timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] - if hasattr(self.scheduler, "set_begin_index"): - self.scheduler.set_begin_index(t_start * self.scheduler.order) return timesteps, num_inference_steps - t_start diff --git a/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py index 9bb68c1d3e..206c3436bb 100644 --- a/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +++ b/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py @@ -716,15 +716,12 @@ class StableDiffusionDiffEditPipeline( f" `source_negative_prompt_embeds` {source_negative_prompt_embeds.shape}." ) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps def get_timesteps(self, num_inference_steps, strength, device): # get the original timestep using init_timestep init_timestep = min(int(num_inference_steps * strength), num_inference_steps) t_start = max(num_inference_steps - init_timestep, 0) timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] - if hasattr(self.scheduler, "set_begin_index"): - self.scheduler.set_begin_index(t_start * self.scheduler.order) return timesteps, num_inference_steps - t_start diff --git a/src/diffusers/schedulers/scheduling_consistency_models.py b/src/diffusers/schedulers/scheduling_consistency_models.py index 5a37886e22..316d657e8d 100644 --- a/src/diffusers/schedulers/scheduling_consistency_models.py +++ b/src/diffusers/schedulers/scheduling_consistency_models.py @@ -434,7 +434,11 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin): # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index if self.begin_index is None: step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] else: + # add noise is called bevore first denoising step to create inital latent(img2img) step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index dfbb8af82d..27213c0af1 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -768,10 +768,14 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - # begin_index is None when the scheduler is used for training + # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index if self.begin_index is None: step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] else: + # add noise is called bevore first denoising step to create inital latent(img2img) step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 4780a34135..1c8ec1090b 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -1011,10 +1011,14 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - # begin_index is None when the scheduler is used for training + # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index if self.begin_index is None: step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] else: + # add noise is called bevore first denoising step to create inital latent(img2img) step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py index 96962c315e..22aeba5750 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py @@ -543,7 +543,11 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin): # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index if self.begin_index is None: step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] else: + # add noise is called bevore first denoising step to create inital latent(img2img) step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index f0a417f83c..3113d61a94 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -961,10 +961,14 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - # begin_index is None when the scheduler is used for training + # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index if self.begin_index is None: step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] else: + # add noise is called bevore first denoising step to create inital latent(img2img) step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() diff --git a/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py index e98f37ca6a..f490da07ab 100644 --- a/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py @@ -669,7 +669,11 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index if self.begin_index is None: step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] else: + # add noise is called bevore first denoising step to create inital latent(img2img) step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() diff --git a/src/diffusers/schedulers/scheduling_edm_euler.py b/src/diffusers/schedulers/scheduling_edm_euler.py index a005b4f20c..7f0ef90294 100644 --- a/src/diffusers/schedulers/scheduling_edm_euler.py +++ b/src/diffusers/schedulers/scheduling_edm_euler.py @@ -367,7 +367,11 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin): # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index if self.begin_index is None: step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] else: + # add noise is called bevore first denoising step to create inital latent(img2img) step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index 01c2024148..81009b4fe7 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -467,7 +467,11 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index if self.begin_index is None: step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] else: + # add noise is called bevore first denoising step to create inital latent(img2img) step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index afe2d1456e..62d0caa4e2 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -562,7 +562,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index if self.begin_index is None: step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] else: + # add noise is called bevore first denoising step to create inital latent(img2img) step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py index 2de5a2913c..14ab01da74 100644 --- a/src/diffusers/schedulers/scheduling_heun_discrete.py +++ b/src/diffusers/schedulers/scheduling_heun_discrete.py @@ -468,7 +468,11 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index if self.begin_index is None: step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] else: + # add noise is called bevore first denoising step to create inital latent(img2img) step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py index 6390a187f2..af50fcf54f 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py @@ -494,7 +494,11 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index if self.begin_index is None: step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] else: + # add noise is called bevore first denoising step to create inital latent(img2img) step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py index 32638000d6..888ba311a9 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py @@ -469,7 +469,11 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin): # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index if self.begin_index is None: step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] else: + # add noise is called bevore first denoising step to create inital latent(img2img) step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 9f759683b5..c8abade3e5 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -461,7 +461,11 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index if self.begin_index is None: step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] else: + # add noise is called bevore first denoising step to create inital latent(img2img) step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index e3f0db8494..a844bcaec3 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -862,10 +862,14 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - # begin_index is None when the scheduler is used for training + # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index if self.begin_index is None: step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] else: + # add noise is called bevore first denoising step to create inital latent(img2img) step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index e4e97d7bfc..dec62e7e46 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -29,6 +29,7 @@ from diffusers import ( AutoencoderKL, DDIMScheduler, DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, LCMScheduler, LMSDiscreteScheduler, PNDMScheduler, @@ -557,6 +558,29 @@ class StableDiffusionSimpleInpaintPipelineFastTests(StableDiffusionInpaintPipeli image_slice2 = images[1, -3:, -3:, -1] assert np.abs(image_slice1.flatten() - image_slice2.flatten()).max() > 1e-2 + def test_stable_diffusion_inpaint_euler(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components(time_cond_proj_dim=256) + sd_pipe = StableDiffusionInpaintPipeline(**components) + sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device, output_pil=False) + half_dim = inputs["image"].shape[2] // 2 + inputs["mask_image"][0, 0, :half_dim, :half_dim] = 0 + + inputs["num_inference_steps"] = 4 + image = sd_pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + + expected_slice = np.array( + [[0.6387283, 0.5564158, 0.58631873, 0.5539942, 0.5494673, 0.6461868, 0.5251618, 0.5497595, 0.5508756]] + ) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-4 + @slow @require_torch_gpu