
Modification on the PAG community pipeline (re) (#7876)

* edited_pag_implementation

* update

---------

Co-authored-by: yiyixuxu <yixu310@gmail.com>
Author: Hyoungwon Cho
Date: 2024-05-08 11:35:15 +09:00 (committed by GitHub)
Parent: 8edaf3b79c
Commit: c2217142bd
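For context, a minimal usage sketch of the community pipeline after this change. Assumptions: the file is loaded as the `pipeline_stable_diffusion_pag` community pipeline and the checkpoint name is illustrative; per the diff below, `pag_drop_rate` and `pag_applied_layers` are no longer accepted, only `pag_scale`, `pag_adaptive_scaling`, and `pag_applied_layers_index`.

import torch
from diffusers import DiffusionPipeline

# load the community PAG pipeline (pipeline name assumed from the community folder)
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # illustrative checkpoint
    custom_pipeline="pipeline_stable_diffusion_pag",
    torch_dtype=torch.float16,
).to("cuda")

# only pag_scale, pag_adaptive_scaling, and pag_applied_layers_index remain
image = pipe(
    "a photo of a corgi wearing a top hat",
    guidance_scale=7.5,
    pag_scale=3.0,
    pag_applied_layers_index=["d4"],  # e.g. ["d4", "d5", "m0"]
).images[0]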


@@ -1,4 +1,5 @@
-# Implementation of StableDiffusionPAGPipeline
+# Implementation of StableDiffusionPipeline with PAG
 # https://ku-cvlab.github.io/Perturbed-Attention-Guidance
+
 import inspect
 from typing import Any, Callable, Dict, List, Optional, Union
@@ -134,8 +135,8 @@ class PAGIdentitySelfAttnProcessor:

         value = attn.to_v(hidden_states_ptb)

-        hidden_states_ptb = torch.zeros(value.shape).to(value.get_device())
-        # hidden_states_ptb = value
+        # hidden_states_ptb = torch.zeros(value.shape).to(value.get_device())
+        hidden_states_ptb = value

         hidden_states_ptb = hidden_states_ptb.to(query.dtype)

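The functional change in this hunk: the perturbed branch now passes the value projection through unchanged instead of zeroing it, i.e. the self-attention map is replaced by the identity so each token attends only to itself. A minimal illustrative sketch of that idea (not the pipeline's code; the function names are made up):

import torch
import torch.nn.functional as F

def standard_self_attention(q, k, v):
    # ordinary scaled dot-product attention: softmax(Q K^T / sqrt(d)) V
    return F.scaled_dot_product_attention(q, k, v)

def pag_perturbed_self_attention(q, k, v):
    # PAG's perturbed branch: identity attention map, so the output is just V
    # (the earlier revision returned torch.zeros_like(v) here instead)
    return v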
@@ -1045,7 +1046,7 @@ class StableDiffusionPAGPipeline(
         return self._pag_scale

     @property
-    def do_adversarial_guidance(self):
+    def do_perturbed_attention_guidance(self):
         return self._pag_scale > 0

     @property
@@ -1056,14 +1057,6 @@ class StableDiffusionPAGPipeline(
     def do_pag_adaptive_scaling(self):
         return self._pag_adaptive_scaling > 0

-    @property
-    def pag_drop_rate(self):
-        return self._pag_drop_rate
-
-    @property
-    def pag_applied_layers(self):
-        return self._pag_applied_layers
-
     @property
     def pag_applied_layers_index(self):
         return self._pag_applied_layers_index
@@ -1080,8 +1073,6 @@ class StableDiffusionPAGPipeline(
         guidance_scale: float = 7.5,
         pag_scale: float = 0.0,
         pag_adaptive_scaling: float = 0.0,
-        pag_drop_rate: float = 0.5,
-        pag_applied_layers: List[str] = ["down"],  # ['down', 'mid', 'up']
         pag_applied_layers_index: List[str] = ["d4"],  # ['d4', 'd5', 'm0']
         negative_prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
@@ -1221,8 +1212,6 @@ class StableDiffusionPAGPipeline(

         self._pag_scale = pag_scale
         self._pag_adaptive_scaling = pag_adaptive_scaling
-        self._pag_drop_rate = pag_drop_rate
-        self._pag_applied_layers = pag_applied_layers
         self._pag_applied_layers_index = pag_applied_layers_index

         # 2. Define call parameters
@@ -1257,13 +1246,13 @@ class StableDiffusionPAGPipeline(
         # to avoid doing two forward passes

         # cfg
-        if self.do_classifier_free_guidance and not self.do_adversarial_guidance:
+        if self.do_classifier_free_guidance and not self.do_perturbed_attention_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
         # pag
-        elif not self.do_classifier_free_guidance and self.do_adversarial_guidance:
+        elif not self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
             prompt_embeds = torch.cat([prompt_embeds, prompt_embeds])
         # both
-        elif self.do_classifier_free_guidance and self.do_adversarial_guidance:
+        elif self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds])

         if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
@@ -1306,7 +1295,7 @@ class StableDiffusionPAGPipeline(
             ).to(device=device, dtype=latents.dtype)

         # 7. Denoising loop
-        if self.do_adversarial_guidance:
+        if self.do_perturbed_attention_guidance:
             down_layers = []
             mid_layers = []
             up_layers = []
@@ -1322,6 +1311,29 @@ class StableDiffusionPAGPipeline(
                     else:
                         raise ValueError(f"Invalid layer type: {layer_type}")

+        # change attention layer in UNet if use PAG
+        if self.do_perturbed_attention_guidance:
+            if self.do_classifier_free_guidance:
+                replace_processor = PAGCFGIdentitySelfAttnProcessor()
+            else:
+                replace_processor = PAGIdentitySelfAttnProcessor()
+
+            drop_layers = self.pag_applied_layers_index
+            for drop_layer in drop_layers:
+                try:
+                    if drop_layer[0] == "d":
+                        down_layers[int(drop_layer[1])].processor = replace_processor
+                    elif drop_layer[0] == "m":
+                        mid_layers[int(drop_layer[1])].processor = replace_processor
+                    elif drop_layer[0] == "u":
+                        up_layers[int(drop_layer[1])].processor = replace_processor
+                    else:
+                        raise ValueError(f"Invalid layer type: {drop_layer[0]}")
+                except IndexError:
+                    raise ValueError(
+                        f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers."
+                    )
+
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -1330,41 +1342,18 @@ class StableDiffusionPAGPipeline(
                     continue

                 # cfg
-                if self.do_classifier_free_guidance and not self.do_adversarial_guidance:
+                if self.do_classifier_free_guidance and not self.do_perturbed_attention_guidance:
                     latent_model_input = torch.cat([latents] * 2)
                 # pag
-                elif not self.do_classifier_free_guidance and self.do_adversarial_guidance:
+                elif not self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
                     latent_model_input = torch.cat([latents] * 2)
                 # both
-                elif self.do_classifier_free_guidance and self.do_adversarial_guidance:
+                elif self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
                     latent_model_input = torch.cat([latents] * 3)
                 # no
                 else:
                     latent_model_input = latents

-                # change attention layer in UNet if use PAG
-                if self.do_adversarial_guidance:
-                    if self.do_classifier_free_guidance:
-                        replace_processor = PAGCFGIdentitySelfAttnProcessor()
-                    else:
-                        replace_processor = PAGIdentitySelfAttnProcessor()
-
-                    drop_layers = self.pag_applied_layers_index
-                    for drop_layer in drop_layers:
-                        try:
-                            if drop_layer[0] == "d":
-                                down_layers[int(drop_layer[1])].processor = replace_processor
-                            elif drop_layer[0] == "m":
-                                mid_layers[int(drop_layer[1])].processor = replace_processor
-                            elif drop_layer[0] == "u":
-                                up_layers[int(drop_layer[1])].processor = replace_processor
-                            else:
-                                raise ValueError(f"Invalid layer type: {drop_layer[0]}")
-                        except IndexError:
-                            raise ValueError(
-                                f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers."
-                            )
-
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                 # predict the noise residual
@@ -1381,14 +1370,14 @@ class StableDiffusionPAGPipeline(
                 # perform guidance

                 # cfg
-                if self.do_classifier_free_guidance and not self.do_adversarial_guidance:
+                if self.do_classifier_free_guidance and not self.do_perturbed_attention_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

                     delta = noise_pred_text - noise_pred_uncond
                     noise_pred = noise_pred_uncond + self.guidance_scale * delta

                 # pag
-                elif not self.do_classifier_free_guidance and self.do_adversarial_guidance:
+                elif not self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
                     noise_pred_original, noise_pred_perturb = noise_pred.chunk(2)

                     signal_scale = self.pag_scale
@@ -1400,7 +1389,7 @@ class StableDiffusionPAGPipeline(
                     noise_pred = noise_pred_original + signal_scale * (noise_pred_original - noise_pred_perturb)

                 # both
-                elif self.do_classifier_free_guidance and self.do_adversarial_guidance:
+                elif self.do_classifier_free_guidance and self.do_perturbed_attention_guidance:
                     noise_pred_uncond, noise_pred_text, noise_pred_text_perturb = noise_pred.chunk(3)

                     signal_scale = self.pag_scale
@@ -1458,11 +1447,8 @@ class StableDiffusionPAGPipeline(
         # Offload all models
         self.maybe_free_model_hooks()

-        if not return_dict:
-            return (image, has_nsfw_concept)
-
         # change attention layer in UNet if use PAG
-        if self.do_adversarial_guidance:
+        if self.do_perturbed_attention_guidance:
             drop_layers = self.pag_applied_layers_index
             for drop_layer in drop_layers:
                 try:
@@ -1479,4 +1465,7 @@ class StableDiffusionPAGPipeline(
f"Invalid layer index: {drop_layer}. Available layers: {len(down_layers)} down layers, {len(mid_layers)} mid layers, {len(up_layers)} up layers."
)
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
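For reference, a sketch of how the chunked noise predictions are combined under the renamed flag. The CFG-only and PAG-only branches follow the lines shown in the diff above; the combined CFG+PAG line is an assumption about the surrounding code (not part of this diff), and the function name is made up for illustration.

import torch

def combine_noise_preds(noise_pred, guidance_scale, pag_scale, do_cfg, do_pag):
    # cfg only: two chunks (uncond, text)
    if do_cfg and not do_pag:
        uncond, text = noise_pred.chunk(2)
        return uncond + guidance_scale * (text - uncond)
    # pag only: two chunks (original, perturbed)
    if do_pag and not do_cfg:
        original, perturbed = noise_pred.chunk(2)
        return original + pag_scale * (original - perturbed)
    # both: three chunks (uncond, text, perturbed text); assumed additive terms
    if do_cfg and do_pag:
        uncond, text, text_perturbed = noise_pred.chunk(3)
        return uncond + guidance_scale * (text - uncond) + pag_scale * (text - text_perturbed)
    # neither guidance enabled
    return noise_pred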