diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
index 89d4278937..cbb6ed4fa1 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
@@ -323,8 +323,6 @@ class StableDiffusionInstructPix2PixPipeline(
         batch_size = prompt_embeds.shape[0]
 
         device = self._execution_device
-        # check if scheduler is in sigmas space
-        scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas")
 
         # 2. Encode input prompt
         prompt_embeds = self._encode_prompt(
@@ -411,15 +409,6 @@ class StableDiffusionInstructPix2PixPipeline(
                     return_dict=False,
                 )[0]
 
-                # Hack:
-                # For karras style schedulers the model does classifer free guidance using the
-                # predicted_original_sample instead of the noise_pred. So we need to compute the
-                # predicted_original_sample here if we are using a karras style scheduler.
-                if scheduler_is_in_sigma_space:
-                    step_index = (self.scheduler.timesteps == t).nonzero()[0].item()
-                    sigma = self.scheduler.sigmas[step_index]
-                    noise_pred = latent_model_input - sigma * noise_pred
-
                 # perform guidance
                 if self.do_classifier_free_guidance:
                     noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3)
@@ -429,15 +418,6 @@ class StableDiffusionInstructPix2PixPipeline(
                         + self.image_guidance_scale * (noise_pred_image - noise_pred_uncond)
                     )
 
-                # Hack:
-                # For karras style schedulers the model does classifer free guidance using the
-                # predicted_original_sample instead of the noise_pred. But the scheduler.step function
-                # expects the noise_pred and computes the predicted_original_sample internally. So we
-                # need to overwrite the noise_pred here such that the value of the computed
-                # predicted_original_sample is correct.
-                if scheduler_is_in_sigma_space:
-                    noise_pred = (noise_pred - latents) / (-sigma)
-
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
 
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
index 8bf07d2f80..51e413d4b5 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
@@ -774,8 +774,6 @@ class StableDiffusionXLInstructPix2PixPipeline(
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0 and image_guidance_scale >= 1.0
-        # check if scheduler is in sigmas space
-        scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas")
 
         # 3. Encode input prompt
         text_encoder_lora_scale = (
@@ -906,15 +904,6 @@ class StableDiffusionXLInstructPix2PixPipeline(
                     return_dict=False,
                 )[0]
 
-                # Hack:
-                # For karras style schedulers the model does classifer free guidance using the
-                # predicted_original_sample instead of the noise_pred. So we need to compute the
-                # predicted_original_sample here if we are using a karras style scheduler.
-                if scheduler_is_in_sigma_space:
-                    step_index = (self.scheduler.timesteps == t).nonzero()[0].item()
-                    sigma = self.scheduler.sigmas[step_index]
-                    noise_pred = latent_model_input - sigma * noise_pred
-
                 # perform guidance
                 if do_classifier_free_guidance:
                     noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3)
@@ -928,15 +917,6 @@ class StableDiffusionXLInstructPix2PixPipeline(
                     # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                     noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
 
-                # Hack:
-                # For karras style schedulers the model does classifer free guidance using the
-                # predicted_original_sample instead of the noise_pred. But the scheduler.step function
-                # expects the noise_pred and computes the predicted_original_sample internally. So we
-                # need to overwrite the noise_pred here such that the value of the computed
-                # predicted_original_sample is correct.
-                if scheduler_is_in_sigma_space:
-                    noise_pred = (noise_pred - latents) / (-sigma)
-
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
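
For context, a minimal sketch of what the removed "Hack:" blocks did, assuming a sigma-based (karras-style) scheduler that exposes `sigmas` and `timesteps`. The function name and signature below are illustrative only and not part of either pipeline; the parameter names mirror the pipeline locals (`latent_model_input`, `noise_pred`, `latents`, `t`) for readability.

def cfg_in_sigma_space(scheduler, latent_model_input, noise_pred, latents, t,
                       guidance_scale, image_guidance_scale):
    # Detect sigma-space (karras-style) schedulers the same way the removed code did.
    scheduler_is_in_sigma_space = hasattr(scheduler, "sigmas")

    if scheduler_is_in_sigma_space:
        # Convert the epsilon prediction into a predicted_original_sample:
        # x0 = x_t - sigma * eps
        step_index = (scheduler.timesteps == t).nonzero()[0].item()
        sigma = scheduler.sigmas[step_index]
        noise_pred = latent_model_input - sigma * noise_pred

    # Three-way InstructPix2Pix guidance (text / image / unconditional).
    noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3)
    noise_pred = (
        noise_pred_uncond
        + guidance_scale * (noise_pred_text - noise_pred_image)
        + image_guidance_scale * (noise_pred_image - noise_pred_uncond)
    )

    if scheduler_is_in_sigma_space:
        # scheduler.step() expects an epsilon prediction, so convert back:
        # eps = (x_t - x0) / sigma == (x0 - x_t) / (-sigma)
        noise_pred = (noise_pred - latents) / (-sigma)

    return noise_pred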