mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Community Pipeline] Add some feature for regional prompting pipeline (#9874)
* [Fix] fix bugs of regional_prompting pipeline * [Feat] add base prompt feature * [Fix] fix __init__ pipeline error * [Fix] delete unused args * [Fix] improve string handling * [Docs] docs to use_base in regional_prompting * make style --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -3379,6 +3379,20 @@ best quality, 3persons in garden, a boy blue shirt BREAK
|
||||
best quality, 3persons in garden, an old man red suit
|
||||
```
|
||||
|
||||
### Use base prompt
|
||||
|
||||
A base prompt is applied to all regions. To set one, write the base prompt first and end it with `ADDBASE`. A base prompt can also be combined with a common prompt (`ADDCOMM`), but the base prompt must be specified first.
|
||||
|
||||
```
|
||||
2d animation style ADDBASE
|
||||
masterpiece, high quality ADDCOMM
|
||||
(blue sky)++ BREAK
|
||||
green hair twintail BREAK
|
||||
book shelf BREAK
|
||||
messy desk BREAK
|
||||
orange++ dress and sofa
|
||||
```
|
||||
|
||||
### Negative prompt
|
||||
|
||||
Negative prompts are equally effective across all regions, but it is also possible to set region-specific negative prompts. To do so, the negative prompt must contain the same number of `BREAK`s as the prompt. If the numbers do not match, the negative prompt is used as-is, without being divided into regions.
|
||||
@@ -3409,6 +3423,7 @@ pipe(prompt=prompt, rp_args=rp_args)
|
||||
### Optional Parameters
|
||||
|
||||
- `save_mask`: In `Prompt` mode, choose whether to output the generated mask along with the image. The default is `False`.
|
||||
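- `base_ratio`: Used with `ADDBASE`. Sets the weight of the base prompt when it is blended with the regional prompts (float, default `0.0`, which disables the blending). For example, if `base_ratio` is set to `0.2`, the prompt embeddings used for generation are `20%*BASE_PROMPT + 80%*REGION_PROMPT`.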
|
||||
|
||||
The pipeline supports `compel` syntax. Prompts written in `compel` syntax are parsed and applied automatically.
|
||||
|
||||
|
||||
@@ -3,13 +3,12 @@ from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
import torchvision.transforms.functional as FF
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from diffusers.utils import USE_PEFT_BACKEND
|
||||
|
||||
|
||||
try:
|
||||
@@ -17,6 +16,7 @@ try:
|
||||
except ImportError:
|
||||
Compel = None
|
||||
|
||||
KBASE = "ADDBASE"
|
||||
KCOMM = "ADDCOMM"
|
||||
KBRK = "BREAK"
|
||||
|
||||
@@ -34,6 +34,11 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
|
||||
Optional
|
||||
rp_args["save_mask"]: True/False (save masks in prompt mode)
|
||||
rp_args["power"]: int (power for attention maps in prompt mode)
|
||||
rp_args["base_ratio"]:
|
||||
float (Sets the ratio of the base prompt)
|
||||
ex) 0.2 (20%*BASE_PROMPT + 80%*REGION_PROMPT)
|
||||
[Use base prompt](https://github.com/hako-mikan/sd-webui-regional-prompter?tab=readme-ov-file#use-base-prompt)
|
||||
|
||||
Pipeline for text-to-image generation using Stable Diffusion.
|
||||
|
||||
@@ -70,6 +75,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
image_encoder: CLIPVisionModelWithProjection = None,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
@@ -80,6 +86,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
scheduler,
|
||||
safety_checker,
|
||||
feature_extractor,
|
||||
image_encoder,
|
||||
requires_safety_checker,
|
||||
)
|
||||
self.register_modules(
|
||||
@@ -90,6 +97,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
image_encoder=image_encoder,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -110,17 +118,40 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
rp_args: Dict[str, str] = None,
|
||||
):
|
||||
active = KBRK in prompt[0] if isinstance(prompt, list) else KBRK in prompt
|
||||
use_base = KBASE in prompt[0] if isinstance(prompt, list) else KBASE in prompt
|
||||
if negative_prompt is None:
|
||||
negative_prompt = "" if isinstance(prompt, str) else [""] * len(prompt)
|
||||
|
||||
device = self._execution_device
|
||||
regions = 0
|
||||
|
||||
self.base_ratio = float(rp_args["base_ratio"]) if "base_ratio" in rp_args else 0.0
|
||||
self.power = int(rp_args["power"]) if "power" in rp_args else 1
|
||||
|
||||
prompts = prompt if isinstance(prompt, list) else [prompt]
|
||||
n_prompts = negative_prompt if isinstance(prompt, str) else [negative_prompt]
|
||||
n_prompts = negative_prompt if isinstance(prompt, list) else [negative_prompt]
|
||||
self.batch = batch = num_images_per_prompt * len(prompts)
|
||||
|
||||
if use_base:
|
||||
bases = prompts.copy()
|
||||
n_bases = n_prompts.copy()
|
||||
|
||||
for i, prompt in enumerate(prompts):
|
||||
parts = prompt.split(KBASE)
|
||||
if len(parts) == 2:
|
||||
bases[i], prompts[i] = parts
|
||||
elif len(parts) > 2:
|
||||
raise ValueError(f"Multiple instances of {KBASE} found in prompt: {prompt}")
|
||||
for i, prompt in enumerate(n_prompts):
|
||||
n_parts = prompt.split(KBASE)
|
||||
if len(n_parts) == 2:
|
||||
n_bases[i], n_prompts[i] = n_parts
|
||||
elif len(n_parts) > 2:
|
||||
raise ValueError(f"Multiple instances of {KBASE} found in negative prompt: {prompt}")
|
||||
|
||||
all_bases_cn, _ = promptsmaker(bases, num_images_per_prompt)
|
||||
all_n_bases_cn, _ = promptsmaker(n_bases, num_images_per_prompt)
|
||||
|
||||
all_prompts_cn, all_prompts_p = promptsmaker(prompts, num_images_per_prompt)
|
||||
all_n_prompts_cn, _ = promptsmaker(n_prompts, num_images_per_prompt)
|
||||
|
||||
@@ -137,8 +168,16 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
|
||||
conds = getcompelembs(all_prompts_cn)
|
||||
unconds = getcompelembs(all_n_prompts_cn)
|
||||
embs = getcompelembs(prompts)
|
||||
n_embs = getcompelembs(n_prompts)
|
||||
base_embs = getcompelembs(all_bases_cn) if use_base else None
|
||||
base_n_embs = getcompelembs(all_n_bases_cn) if use_base else None
|
||||
# When using base, it seems more reasonable to use base prompts as prompt_embeddings rather than regional prompts
|
||||
embs = getcompelembs(prompts) if not use_base else base_embs
|
||||
n_embs = getcompelembs(n_prompts) if not use_base else base_n_embs
|
||||
|
||||
if use_base and self.base_ratio > 0:
|
||||
conds = self.base_ratio * base_embs + (1 - self.base_ratio) * conds
|
||||
unconds = self.base_ratio * base_n_embs + (1 - self.base_ratio) * unconds
|
||||
|
||||
prompt = negative_prompt = None
|
||||
else:
|
||||
conds = self.encode_prompt(prompts, device, 1, True)[0]
|
||||
@@ -147,6 +186,18 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
if equal
|
||||
else self.encode_prompt(all_n_prompts_cn, device, 1, True)[0]
|
||||
)
|
||||
|
||||
if use_base and self.base_ratio > 0:
|
||||
base_embs = self.encode_prompt(bases, device, 1, True)[0]
|
||||
base_n_embs = (
|
||||
self.encode_prompt(n_bases, device, 1, True)[0]
|
||||
if equal
|
||||
else self.encode_prompt(all_n_bases_cn, device, 1, True)[0]
|
||||
)
|
||||
|
||||
conds = self.base_ratio * base_embs + (1 - self.base_ratio) * conds
|
||||
unconds = self.base_ratio * base_n_embs + (1 - self.base_ratio) * unconds
|
||||
|
||||
embs = n_embs = None
|
||||
|
||||
if not active:
|
||||
@@ -225,8 +276,6 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
args = () if USE_PEFT_BACKEND else (scale,)
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
@@ -247,16 +296,15 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
args = () if USE_PEFT_BACKEND else (scale,)
|
||||
query = attn.to_q(hidden_states, *args)
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states, *args)
|
||||
value = attn.to_v(encoder_hidden_states, *args)
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
@@ -283,7 +331,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states, *args)
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
@@ -410,9 +458,9 @@ def promptsmaker(prompts, batch):
|
||||
add = ""
|
||||
if KCOMM in prompt:
|
||||
add, prompt = prompt.split(KCOMM)
|
||||
add = add + " "
|
||||
prompts = prompt.split(KBRK)
|
||||
out_p.append([add + p for p in prompts])
|
||||
add = add.strip() + " "
|
||||
prompts = [p.strip() for p in prompt.split(KBRK)]
|
||||
out_p.append([add + p for i, p in enumerate(prompts)])
|
||||
out = [None] * batch * len(out_p[0]) * len(out_p)
|
||||
for p, prs in enumerate(out_p): # inputs prompts
|
||||
for r, pr in enumerate(prs): # prompts for regions
|
||||
@@ -449,7 +497,6 @@ def make_cells(ratios):
|
||||
add = []
|
||||
startend(add, inratios[1:])
|
||||
icells.append(add)
|
||||
|
||||
return ocells, icells, sum(len(cell) for cell in icells)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user