mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Fixed noise_pred_text referenced before assignment. (#9537)
* Fixed local variable noise_pred_text referenced before assignment when using PAG with guidance scale and guidance rescale at the same time. * Fixed style. * Made returning text pred noise an argument.
This commit is contained in:
@@ -98,7 +98,9 @@ class PAGMixin:
|
||||
else:
|
||||
return self.pag_scale
|
||||
|
||||
def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_guidance, guidance_scale, t):
|
||||
def _apply_perturbed_attention_guidance(
|
||||
self, noise_pred, do_classifier_free_guidance, guidance_scale, t, return_pred_text=False
|
||||
):
|
||||
r"""
|
||||
Apply perturbed attention guidance to the noise prediction.
|
||||
|
||||
@@ -107,9 +109,11 @@ class PAGMixin:
|
||||
do_classifier_free_guidance (bool): Whether to apply classifier-free guidance.
|
||||
guidance_scale (float): The scale factor for the guidance term.
|
||||
t (int): The current time step.
|
||||
return_pred_text (bool): Whether to return the text noise prediction.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The updated noise prediction tensor after applying perturbed attention guidance.
|
||||
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: The updated noise prediction tensor after applying
|
||||
perturbed attention guidance and the text noise prediction.
|
||||
"""
|
||||
pag_scale = self._get_pag_scale(t)
|
||||
if do_classifier_free_guidance:
|
||||
@@ -122,6 +126,8 @@ class PAGMixin:
|
||||
else:
|
||||
noise_pred_text, noise_pred_perturb = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb)
|
||||
if return_pred_text:
|
||||
return noise_pred, noise_pred_text
|
||||
return noise_pred
|
||||
|
||||
def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance):
|
||||
|
||||
@@ -893,8 +893,8 @@ class HunyuanDiTPAGPipeline(DiffusionPipeline, PAGMixin):
|
||||
|
||||
# perform guidance
|
||||
if self.do_perturbed_attention_guidance:
|
||||
noise_pred = self._apply_perturbed_attention_guidance(
|
||||
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
|
||||
noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
|
||||
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True
|
||||
)
|
||||
elif self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
|
||||
@@ -993,8 +993,8 @@ class StableDiffusionPAGPipeline(
|
||||
|
||||
# perform guidance
|
||||
if self.do_perturbed_attention_guidance:
|
||||
noise_pred = self._apply_perturbed_attention_guidance(
|
||||
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
|
||||
noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
|
||||
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True
|
||||
)
|
||||
|
||||
elif self.do_classifier_free_guidance:
|
||||
|
||||
@@ -1237,8 +1237,8 @@ class StableDiffusionXLPAGPipeline(
|
||||
|
||||
# perform guidance
|
||||
if self.do_perturbed_attention_guidance:
|
||||
noise_pred = self._apply_perturbed_attention_guidance(
|
||||
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
|
||||
noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
|
||||
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True
|
||||
)
|
||||
|
||||
elif self.do_classifier_free_guidance:
|
||||
|
||||
@@ -1437,8 +1437,8 @@ class StableDiffusionXLPAGImg2ImgPipeline(
|
||||
|
||||
# perform guidance
|
||||
if self.do_perturbed_attention_guidance:
|
||||
noise_pred = self._apply_perturbed_attention_guidance(
|
||||
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
|
||||
noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
|
||||
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True
|
||||
)
|
||||
elif self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
|
||||
@@ -1649,8 +1649,8 @@ class StableDiffusionXLPAGInpaintPipeline(
|
||||
|
||||
# perform guidance
|
||||
if self.do_perturbed_attention_guidance:
|
||||
noise_pred = self._apply_perturbed_attention_guidance(
|
||||
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
|
||||
noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
|
||||
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True
|
||||
)
|
||||
elif self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
|
||||
Reference in New Issue
Block a user