Modification on the PAG community pipeline (re) (#7876)
* edited_pag_implementation

* update

---------

Co-authored-by: yiyixuxu <yixu310@gmail.com>
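As context for the diff below, a minimal usage sketch of the community pipeline this commit touches. The model id and the custom_pipeline identifier are placeholders (assumptions, not taken from the commit); only the PAG arguments kept by this change (pag_scale, pag_adaptive_scaling, pag_applied_layers_index) are passed.

import torch
from diffusers import DiffusionPipeline

# Assumed identifiers: the custom_pipeline name and the model id below are
# placeholders for illustration, not confirmed by this commit.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    custom_pipeline="pipeline_stable_diffusion_pag",
    torch_dtype=torch.float16,
).to("cuda")

image = pipe(
    prompt="a photo of an astronaut riding a horse",
    guidance_scale=7.5,               # CFG scale; the diff shows PAG runs with or without CFG
    pag_scale=3.0,                    # > 0 enables perturbed-attention guidance
    pag_adaptive_scaling=0.0,         # > 0 decays the PAG term over the denoising steps
    pag_applied_layers_index=["m0"],  # self-attention blocks to perturb, e.g. 'd4', 'd5', 'm0'
    num_inference_steps=25,
).images[0]
image.save("pag_sample.png")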
@@ -1,4 +1,5 @@
-# Implementation of StableDiffusionPAGPipeline
+# Implementation of StableDiffusionPipeline with PAG
+# https://ku-cvlab.github.io/Perturbed-Attention-Guidance

 import inspect
 from typing import Any, Callable, Dict, List, Optional, Union
@@ -134,8 +135,8 @@ class PAGIdentitySelfAttnProcessor:

         value = attn.to_v(hidden_states_ptb)

-        hidden_states_ptb = torch.zeros(value.shape).to(value.get_device())
-        # hidden_states_ptb = value
+        # hidden_states_ptb = torch.zeros(value.shape).to(value.get_device())
+        hidden_states_ptb = value

         hidden_states_ptb = hidden_states_ptb.to(query.dtype)

@@ -1045,7 +1046,7 @@ class StableDiffusionPAGPipeline(
         return self._pag_scale

     @property
-    def do_adversarial_guidance(self):
+    def do_perturbed_attention_guidance(self):
         return self._pag_scale > 0

     @property
@@ -1056,14 +1057,6 @@ class StableDiffusionPAGPipeline(
     def do_pag_adaptive_scaling(self):
         return self._pag_adaptive_scaling > 0

     @property
-    def pag_drop_rate(self):
-        return self._pag_drop_rate
-
-    @property
-    def pag_applied_layers(self):
-        return self._pag_applied_layers
-
-    @property
     def pag_applied_layers_index(self):
         return self._pag_applied_layers_index
@@ -1080,8 +1073,6 @@ class StableDiffusionPAGPipeline(
         guidance_scale: float = 7.5,
         pag_scale: float = 0.0,
         pag_adaptive_scaling: float = 0.0,
-        pag_drop_rate: float = 0.5,
-        pag_applied_layers: List[str] = ["down"],  # ['down', 'mid', 'up']
         pag_applied_layers_index: List[str] = ["d4"],  # ['d4', 'd5', 'm0']
         negative_prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
@@ -1221,8 +1212,6 @@ class StableDiffusionPAGPipeline(

         self._pag_scale = pag_scale
         self._pag_adaptive_scaling = pag_adaptive_scaling
-        self._pag_drop_rate = pag_drop_rate
-        self._pag_applied_layers = pag_applied_layers
         self._pag_applied_layers_index = pag_applied_layers_index

         # 2. Define call parameters
@@ -1257,13 +1246,13 @@ class StableDiffusionPAGPipeline(
         # to avoid doing two forward passes

         # cfg
-        if self.do_classifier_free_guidance and not self.do_adversarial_guidance:
+        if self.do_classifier_free_guidance and not self.do_perturbed_attention_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
         # pag
-        elif not self.do_classifier_free_guidance and self.do_adversarial_guidance:
+        elif not self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
             prompt_embeds = torch.cat([prompt_embeds, prompt_embeds])
         # both
-        elif self.do_classifier_free_guidance and self.do_adversarial_guidance:
+        elif self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds])

         if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
@@ -1306,7 +1295,7 @@ class StableDiffusionPAGPipeline(
         ).to(device=device, dtype=latents.dtype)

         # 7. Denoising loop
-        if self.do_adversarial_guidance:
+        if self.do_perturbed_attention_guidance:
             down_layers = []
             mid_layers = []
             up_layers = []
@@ -1322,6 +1311,29 @@ class StableDiffusionPAGPipeline(
                 else:
                     raise ValueError(f"Invalid layer type: {layer_type}")

+        # change attention layer in UNet if use PAG
+        if self.do_perturbed_attention_guidance:
+            if self.do_classifier_free_guidance:
+                replace_processor = PAGCFGIdentitySelfAttnProcessor()
+            else:
+                replace_processor = PAGIdentitySelfAttnProcessor()
+
+            drop_layers = self.pag_applied_layers_index
+            for drop_layer in drop_layers:
+                try:
+                    if drop_layer[0] == "d":
+                        down_layers[int(drop_layer[1])].processor = replace_processor
+                    elif drop_layer[0] == "m":
+                        mid_layers[int(drop_layer[1])].processor = replace_processor
+                    elif drop_layer[0] == "u":
+                        up_layers[int(drop_layer[1])].processor = replace_processor
+                    else:
+                        raise ValueError(f"Invalid layer type: {drop_layer[0]}")
+                except IndexError:
+                    raise ValueError(
+                        f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers."
+                    )
+
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -1330,41 +1342,18 @@ class StableDiffusionPAGPipeline(
                     continue

                 # cfg
-                if self.do_classifier_free_guidance and not self.do_adversarial_guidance:
+                if self.do_classifier_free_guidance and not self.do_perturbed_attention_guidance:
                     latent_model_input = torch.cat([latents] * 2)
                 # pag
-                elif not self.do_classifier_free_guidance and self.do_adversarial_guidance:
+                elif not self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
                     latent_model_input = torch.cat([latents] * 2)
                 # both
-                elif self.do_classifier_free_guidance and self.do_adversarial_guidance:
+                elif self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
                     latent_model_input = torch.cat([latents] * 3)
                 # no
                 else:
                     latent_model_input = latents

-                # change attention layer in UNet if use PAG
-                if self.do_adversarial_guidance:
-                    if self.do_classifier_free_guidance:
-                        replace_processor = PAGCFGIdentitySelfAttnProcessor()
-                    else:
-                        replace_processor = PAGIdentitySelfAttnProcessor()
-
-                    drop_layers = self.pag_applied_layers_index
-                    for drop_layer in drop_layers:
-                        try:
-                            if drop_layer[0] == "d":
-                                down_layers[int(drop_layer[1])].processor = replace_processor
-                            elif drop_layer[0] == "m":
-                                mid_layers[int(drop_layer[1])].processor = replace_processor
-                            elif drop_layer[0] == "u":
-                                up_layers[int(drop_layer[1])].processor = replace_processor
-                            else:
-                                raise ValueError(f"Invalid layer type: {drop_layer[0]}")
-                        except IndexError:
-                            raise ValueError(
-                                f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers."
-                            )
-
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                 # predict the noise residual
@@ -1381,14 +1370,14 @@ class StableDiffusionPAGPipeline(
                 # perform guidance

                 # cfg
-                if self.do_classifier_free_guidance and not self.do_adversarial_guidance:
+                if self.do_classifier_free_guidance and not self.do_perturbed_attention_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

                     delta = noise_pred_text - noise_pred_uncond
                     noise_pred = noise_pred_uncond + self.guidance_scale * delta

                 # pag
-                elif not self.do_classifier_free_guidance and self.do_adversarial_guidance:
+                elif not self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
                     noise_pred_original, noise_pred_perturb = noise_pred.chunk(2)

                     signal_scale = self.pag_scale
@@ -1400,7 +1389,7 @@ class StableDiffusionPAGPipeline(
                     noise_pred = noise_pred_original + signal_scale * (noise_pred_original - noise_pred_perturb)

                 # both
-                elif self.do_classifier_free_guidance and self.do_adversarial_guidance:
+                elif self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
                     noise_pred_uncond, noise_pred_text, noise_pred_text_perturb = noise_pred.chunk(3)

                     signal_scale = self.pag_scale
@@ -1458,11 +1447,8 @@ class StableDiffusionPAGPipeline(
         # Offload all models
         self.maybe_free_model_hooks()

-        if not return_dict:
-            return (image, has_nsfw_concept)
-
         # change attention layer in UNet if use PAG
-        if self.do_adversarial_guidance:
+        if self.do_perturbed_attention_guidance:
             drop_layers = self.pag_applied_layers_index
             for drop_layer in drop_layers:
                 try:
@@ -1479,4 +1465,7 @@ class StableDiffusionPAGPipeline(
                         f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers."
                     )

+        if not return_dict:
+            return (image, has_nsfw_concept)
+
         return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
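For readers skimming the hunks above, a hedged sketch of how the "both" branch typically combines the three chunks, following the standard perturbed-attention-guidance formulation. The exact combination line falls outside the hunks shown, and the tensor shapes and scales here are illustrative only.

import torch

# Illustrative tensors only; in the pipeline these come from the UNet forward pass
# on the [uncond, text, text_perturb] batch.
noise_pred = torch.randn(3, 4, 64, 64)
guidance_scale, pag_scale = 7.5, 3.0

noise_pred_uncond, noise_pred_text, noise_pred_text_perturb = noise_pred.chunk(3)
signal_scale = pag_scale  # pag_adaptive_scaling > 0 would shrink this as denoising progresses

# classifier-free guidance term plus perturbed-attention guidance term (assumed combination)
noise_pred = (
    noise_pred_uncond
    + guidance_scale * (noise_pred_text - noise_pred_uncond)
    + signal_scale * (noise_pred_text - noise_pred_text_perturb)
)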