mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Perturbed-Attention Guidance (#7512)
* pag_initial * pag_docs * edit_docs * custom * typo * delete_docs * whitespace * make style --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -3743,3 +3743,80 @@ onestep_image = pipe(prompt, num_inference_steps=1).images[0]
|
||||
# Multistep sampling
|
||||
multistep_image = pipe(prompt, num_inference_steps=4).images[0]
|
||||
```
|
||||
|
||||
# Perturbed-Attention Guidance
|
||||
|
||||
[Project](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/) / [arXiv](https://arxiv.org/abs/2403.17377) / [GitHub](https://github.com/KU-CVLAB/Perturbed-Attention-Guidance)
|
||||
|
||||
This implementation is based on [Diffusers](https://huggingface.co/docs/diffusers/index). `StableDiffusionPAGPipeline` is a modification of `StableDiffusionPipeline` that adds support for Perturbed-Attention Guidance (PAG).
|
||||
|
||||
## Example Usage
|
||||
|
||||
```
|
||||
import os
|
||||
import torch
|
||||
|
||||
from accelerate.utils import set_seed
|
||||
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from diffusers.utils import load_image, make_image_grid
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
custom_pipeline="hyoungwoncho/sd_perturbed_attention_guidance",
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
|
||||
device="cuda"
|
||||
pipe = pipe.to(device)
|
||||
|
||||
pag_scale = 5.0
|
||||
pag_applied_layers_index = ['m0']
|
||||
|
||||
batch_size = 4
|
||||
seed=10
|
||||
|
||||
base_dir = "./results/"
|
||||
grid_dir = base_dir + "/pag" + str(pag_scale) + "/"
|
||||
|
||||
if not os.path.exists(grid_dir):
|
||||
os.makedirs(grid_dir)
|
||||
|
||||
set_seed(seed)
|
||||
|
||||
latent_input = randn_tensor(shape=(batch_size,4,64,64),generator=None, device=device, dtype=torch.float16)
|
||||
|
||||
output_baseline = pipe(
|
||||
"",
|
||||
width=512,
|
||||
height=512,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=0.0,
|
||||
pag_scale=0.0,
|
||||
pag_applied_layers_index=pag_applied_layers_index,
|
||||
num_images_per_prompt=batch_size,
|
||||
latents=latent_input
|
||||
).images
|
||||
|
||||
output_pag = pipe(
|
||||
"",
|
||||
width=512,
|
||||
height=512,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=0.0,
|
||||
pag_scale=5.0,
|
||||
pag_applied_layers_index=pag_applied_layers_index,
|
||||
num_images_per_prompt=batch_size,
|
||||
latents=latent_input
|
||||
).images
|
||||
|
||||
grid_image = make_image_grid(output_baseline + output_pag, rows=2, cols=batch_size)
|
||||
grid_image.save(grid_dir + "sample.png")
|
||||
```
|
||||
|
||||
## PAG Parameters
|
||||
|
||||
pag_scale : guidance scale of PAG (ex: 5.0)
|
||||
|
||||
pag_applied_layers_index : indices of the layers to which the perturbation is applied (ex: ['m0'])
|
||||
|
||||
1477
examples/community/pipeline_stable_diffusion_pag.py
Normal file
1477
examples/community/pipeline_stable_diffusion_pag.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user