From 86bd991ee5c9c669e22f09693d68b60d0ec59dd1 Mon Sep 17 00:00:00 2001 From: v2ray <60914079+LagPixelLOL@users.noreply.github.com> Date: Wed, 9 Oct 2024 03:27:10 +0800 Subject: [PATCH] Fixed noise_pred_text referenced before assignment. (#9537) * Fixed local variable noise_pred_text referenced before assignment when using PAG with guidance scale and guidance rescale at the same time. * Fixed style. * Made returning text pred noise an argument. --- src/diffusers/pipelines/pag/pag_utils.py | 10 ++++++++-- src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py | 4 ++-- src/diffusers/pipelines/pag/pipeline_pag_sd.py | 4 ++-- src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py | 4 ++-- .../pipelines/pag/pipeline_pag_sd_xl_img2img.py | 4 ++-- .../pipelines/pag/pipeline_pag_sd_xl_inpaint.py | 4 ++-- 6 files changed, 18 insertions(+), 12 deletions(-) diff --git a/src/diffusers/pipelines/pag/pag_utils.py b/src/diffusers/pipelines/pag/pag_utils.py index 728f730c99..7a6e30a3c6 100644 --- a/src/diffusers/pipelines/pag/pag_utils.py +++ b/src/diffusers/pipelines/pag/pag_utils.py @@ -98,7 +98,9 @@ class PAGMixin: else: return self.pag_scale - def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_guidance, guidance_scale, t): + def _apply_perturbed_attention_guidance( + self, noise_pred, do_classifier_free_guidance, guidance_scale, t, return_pred_text=False + ): r""" Apply perturbed attention guidance to the noise prediction. @@ -107,9 +109,11 @@ class PAGMixin: do_classifier_free_guidance (bool): Whether to apply classifier-free guidance. guidance_scale (float): The scale factor for the guidance term. t (int): The current time step. + return_pred_text (bool): Whether to return the text noise prediction. Returns: - torch.Tensor: The updated noise prediction tensor after applying perturbed attention guidance. + Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: The updated noise prediction tensor after applying + perturbed attention guidance and the text noise prediction. """ pag_scale = self._get_pag_scale(t) if do_classifier_free_guidance: @@ -122,6 +126,8 @@ class PAGMixin: else: noise_pred_text, noise_pred_perturb = noise_pred.chunk(2) noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb) + if return_pred_text: + return noise_pred, noise_pred_text return noise_pred def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance): diff --git a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py index 63126cc5aa..4663db3a15 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py @@ -893,8 +893,8 @@ class HunyuanDiTPAGPipeline(DiffusionPipeline, PAGMixin): # perform guidance if self.do_perturbed_attention_guidance: - noise_pred = self._apply_perturbed_attention_guidance( - noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True ) elif self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd.py b/src/diffusers/pipelines/pag/pipeline_pag_sd.py index c6a4f7f42c..e9742b08af 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd.py @@ -993,8 +993,8 @@ class StableDiffusionPAGPipeline( # perform guidance if self.do_perturbed_attention_guidance: - noise_pred = self._apply_perturbed_attention_guidance( - noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True ) elif self.do_classifier_free_guidance: diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py index 18fc06c1f9..8da4349594 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py @@ -1237,8 +1237,8 @@ class StableDiffusionXLPAGPipeline( # perform guidance if self.do_perturbed_attention_guidance: - noise_pred = self._apply_perturbed_attention_guidance( - noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True ) elif self.do_classifier_free_guidance: diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py index dc85aaaca3..4c2c4e5aa3 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py @@ -1437,8 +1437,8 @@ class StableDiffusionXLPAGImg2ImgPipeline( # perform guidance if self.do_perturbed_attention_guidance: - noise_pred = self._apply_perturbed_attention_guidance( - noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True ) elif self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py index f5ebf43009..49e4c5ffd5 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py @@ -1649,8 +1649,8 @@ class StableDiffusionXLPAGInpaintPipeline( # perform guidance if self.do_perturbed_attention_guidance: - noise_pred = self._apply_perturbed_attention_guidance( - noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True ) elif self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)