1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Fixed noise_pred_text referenced before assignment. (#9537)

* Fixed local variable noise_pred_text referenced before assignment when using PAG with guidance scale and guidance rescale at the same time.

* Fixed style.

* Made returning text pred noise an argument.
This commit is contained in:
v2ray
2024-10-09 03:27:10 +08:00
committed by GitHub
parent 02eeb8e77e
commit 86bd991ee5
6 changed files with 18 additions and 12 deletions

View File

@@ -98,7 +98,9 @@ class PAGMixin:
else:
return self.pag_scale
def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_guidance, guidance_scale, t):
def _apply_perturbed_attention_guidance(
self, noise_pred, do_classifier_free_guidance, guidance_scale, t, return_pred_text=False
):
r"""
Apply perturbed attention guidance to the noise prediction.
@@ -107,9 +109,11 @@ class PAGMixin:
do_classifier_free_guidance (bool): Whether to apply classifier-free guidance.
guidance_scale (float): The scale factor for the guidance term.
t (int): The current time step.
return_pred_text (bool): Whether to return the text noise prediction.
Returns:
torch.Tensor: The updated noise prediction tensor after applying perturbed attention guidance.
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: The updated noise prediction tensor after applying
perturbed attention guidance and the text noise prediction.
"""
pag_scale = self._get_pag_scale(t)
if do_classifier_free_guidance:
@@ -122,6 +126,8 @@ class PAGMixin:
else:
noise_pred_text, noise_pred_perturb = noise_pred.chunk(2)
noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb)
if return_pred_text:
return noise_pred, noise_pred_text
return noise_pred
def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance):

View File

@@ -893,8 +893,8 @@ class HunyuanDiTPAGPipeline(DiffusionPipeline, PAGMixin):
# perform guidance
if self.do_perturbed_attention_guidance:
noise_pred = self._apply_perturbed_attention_guidance(
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True
)
elif self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

View File

@@ -993,8 +993,8 @@ class StableDiffusionPAGPipeline(
# perform guidance
if self.do_perturbed_attention_guidance:
noise_pred = self._apply_perturbed_attention_guidance(
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True
)
elif self.do_classifier_free_guidance:

View File

@@ -1237,8 +1237,8 @@ class StableDiffusionXLPAGPipeline(
# perform guidance
if self.do_perturbed_attention_guidance:
noise_pred = self._apply_perturbed_attention_guidance(
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True
)
elif self.do_classifier_free_guidance:

View File

@@ -1437,8 +1437,8 @@ class StableDiffusionXLPAGImg2ImgPipeline(
# perform guidance
if self.do_perturbed_attention_guidance:
noise_pred = self._apply_perturbed_attention_guidance(
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True
)
elif self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

View File

@@ -1649,8 +1649,8 @@ class StableDiffusionXLPAGInpaintPipeline(
# perform guidance
if self.do_perturbed_attention_guidance:
noise_pred = self._apply_perturbed_attention_guidance(
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True
)
elif self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)