
[Community Pipeline] Regional Prompting Pipeline (#6015)

* Update README.md

* Update README.md

* Add files via upload

* Update README.md

* Update examples/community/README.md

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
hako-mikan
2023-12-02 00:22:59 +09:00
committed by GitHub
parent bc1d28c888
commit 46c751e970
2 changed files with 687 additions and 0 deletions


@@ -48,6 +48,7 @@ prompt-to-prompt | change parts of a prompt and retain image structure (see [pap
| Latent Consistency Pipeline | Implementation of [Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://arxiv.org/abs/2310.04378) | [Latent Consistency Pipeline](#latent-consistency-pipeline) | - | [Simian Luo](https://github.com/luosiallen) |
| Latent Consistency Img2img Pipeline | Img2img pipeline for Latent Consistency Models | [Latent Consistency Img2Img Pipeline](#latent-consistency-img2img-pipeline) | - | [Logan Zoellner](https://github.com/nagolinc) |
| Latent Consistency Interpolation Pipeline | Interpolate the latent space of Latent Consistency Models with multiple prompts | [Latent Consistency Interpolation Pipeline](#latent-consistency-interpolation-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1pK3NrLWJSiJsBynLns1K1-IDTW9zbPvl?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
| Regional Prompting Pipeline | Assign multiple prompts for different regions | [Regional Prompting Pipeline](#regional-prompting-pipeline) | - | [hako-mikan](https://github.com/hako-mikan) |
| LDM3D-sr (LDM3D upscaler) | Upscale low resolution RGB and depth inputs to high resolution | [StableDiffusionUpscaleLDM3D Pipeline](https://github.com/estelleafl/diffusers/tree/ldm3d_upscaler_community/examples/community#stablediffusionupscaleldm3d-pipeline) | - | [Estelle Aflalo](https://github.com/estelleafl) |
@@ -2524,6 +2525,181 @@ images[0].save("controlnet_and_adapter_inpaint.png")
```
### Regional Prompting Pipeline
This pipeline is a port of the [Regional Prompter extension](https://github.com/hako-mikan/sd-webui-regional-prompter) for [Stable Diffusion web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui) to diffusers.
It implements a pipeline for the Stable Diffusion model that divides the canvas into multiple regions, each with its own prompt. Regions can be specified in two ways: the `Cols` and `Rows` modes for grid-like divisions, or the `Prompt` mode, where regions are computed from the prompts themselves.
![sample](https://github.com/hako-mikan/sd-webui-regional-prompter/blob/imgs/rp_pipeline1.png)
### Usage
### Sample Code
```
import time

from examples.community.regional_prompting_stable_diffusion import RegionalPromptingStableDiffusionPipeline

# model_path, vae and negative_prompt are assumed to be defined beforehand
pipe = RegionalPromptingStableDiffusionPipeline.from_single_file(model_path, vae=vae)

rp_args = {
    "mode": "rows",
    "div": "1;1;1",
}

prompt = """
green hair twintail BREAK
red blouse BREAK
blue skirt
"""

images = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    guidance_scale=7.5,
    height=768,
    width=512,
    num_inference_steps=20,
    num_images_per_prompt=1,
    rp_args=rp_args,
).images

timestamp = time.strftime(r"%Y%m%d%H%M%S")
for i, image in enumerate(images):
    image.save(f"img-{timestamp}-{i}.png")
```
### Cols, Rows mode
In the `Cols` and `Rows` modes, you can split the canvas vertically or horizontally and assign a prompt to each region. The split ratio is specified through `div`, for example `3;3;2` or `0.1;0.5`. Furthermore, as described later, each split can itself be subdivided to specify more complex regions. A cols-mode variant is sketched after the example below.
In the example below, the canvas is divided into three rows, and a separate prompt is applied to each. The prompts are separated by `BREAK` and applied to their respective regions in order.
![sample](https://github.com/hako-mikan/sd-webui-regional-prompter/blob/imgs/rp_pipeline2.png)
```
green hair twintail BREAK
red blouse BREAK
blue skirt
```
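Switching to column-wise splitting only requires changing `rp_args`. A minimal sketch, reusing the `pipe` defined above (the `3;3;2` ratio is taken from the description above):
```
rp_args = {
    "mode": "cols",  # split the canvas vertically
    "div": "3;3;2",  # three columns in a 3:3:2 ratio, one BREAK-separated prompt each
}
```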
### 2-Dimensional division
The prompt consists of instructions separated by `BREAK`, assigned to regions of a two-dimensional grid. The canvas is first split in the main direction (here, rows): the single semicolon `;` in `div` divides the space into an upper and a lower section, in a `1:2` ratio given by the first value of each group. Commas then define the sub-splits: the upper row is split in a `2:1:1` ratio and the lower row in a `4:6` ratio. In the reference image, regions are numbered from the top left: the blue sky is the first region, the green hair the second, the bookshelf the third, and so on. The terrarium on the desk occupies the fourth region, and the orange dress and sofa the fifth, conforming to their respective splits.
```
rp_args = {
"mode":"rows",
"div": "1,2,1,1;2,4,6"
}
prompt ="""
blue sky BREAK
green hair BREAK
book shelf BREAK
terrarium on desk BREAK
orange dress and sofa
"""
```
![sample](https://github.com/hako-mikan/sd-webui-regional-prompter/blob/imgs/rp_pipeline4.png)
### Prompt Mode
Specifying regions in advance has its limitations: fixed regions get in the way of complex shapes and dynamic compositions. In `Prompt` mode, the regions are instead determined after image generation has begun, from the attention the prompt itself produces. This makes it possible to accommodate such compositions and complex regions.
For further information, see [here](https://github.com/hako-mikan/sd-webui-regional-prompter/blob/main/prompt_en.md).
### Syntax
```
baseprompt target1 target2 BREAK
effect1, target1 BREAK
effect2 ,target2
```
First, write the base prompt. In the base prompt, include the words (`target1`, `target2`) for which you want masks to be created. Then, after a `BREAK`, write the prompt corresponding to `target1`, followed by a comma and `target1` itself. The targets do not have to appear in the same order in the base prompt and in the `BREAK`-separated sections:
```
target2 baseprompt target1 BREAK
effect1, target1 BREAK
effect2 ,target2
```
is also valid.
### Sample
In this example, masks are computed for `shirt`, `tie`, and `skirt`, and color prompts are applied only to those regions.
```
rp_args = {
"mode":"prompt-ex",
"save_mask":True,
"th": "0.4,0.6,0.6",
}
prompt ="""
a girl in street with shirt, tie, skirt BREAK
red, shirt BREAK
green, tie BREAK
blue , skirt
"""
```
![sample](https://github.com/hako-mikan/sd-webui-regional-prompter/blob/imgs/rp_pipeline3.png)
### Threshold
The threshold determines the mask created from the prompt. Since the useful range varies widely with the target, a separate threshold can be set for each mask; if multiple regions are used, enter the values separated by commas, ordered by `BREAK`. For example, hair tends to produce ambiguous attention and requires a small value, while a face tends to be well localized and can use a larger value.
```
a lady ,hair, face BREAK
red, hair BREAK
tanned ,face
```
`threshold : 0.4,0.6`
If only one value is given for multiple regions, it is applied to all of them.
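A minimal sketch of passing these thresholds (assuming the two-region prompt above):
```
rp_args = {
    "mode": "prompt",
    "th": "0.4,0.6",  # 0.4 for the ambiguous hair mask, 0.6 for the well-localized face mask
}
```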
### Prompt and Prompt-EX
The difference is that in `Prompt` mode, overlapping regions are added together, whereas in `Prompt-EX` mode, overlapping regions are overwritten sequentially. Since the masks are processed in order, listing targets with larger regions first makes it easier for the effect of smaller regions to survive intact.
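For instance (thresholds reused from the sample above, purely illustrative):
```
rp_args = {
    "mode": "prompt-ex",  # overwrite overlapping regions in order; "prompt" adds them instead
    "th": "0.4,0.6,0.6",
}
```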
### Accuracy
For a 512 × 512 image, Attention mode reduces a region to about 8 × 8 deep in the U-Net, so small regions get mixed up; Latent mode computes at 64 × 64, so regions are exact. For example:
```
girl hair twintail frills,ribbons, dress, face BREAK
girl, ,face
```
### Mask
When an image is generated, the generated masks are displayed alongside it. They are created at the same size as the image, but are actually used at a much smaller size.
### Use common prompt
You can apply a common prompt to every region by writing it first and separating it with `ADDCOMM`. This is useful when you want to include elements shared by all regions: for example, when generating a picture of three people with different appearances, the instruction 'three people' needs to reach every region. It is also handy for inserting quality tags. For example, if you write:
```
best quality, 3persons in garden, ADDCOMM
a girl white dress BREAK
a boy blue shirt BREAK
an old man red suit
```
With the common prompt enabled, this prompt is converted to the following:
```
best quality, 3persons in garden, a girl white dress BREAK
best quality, 3persons in garden, a boy blue shirt BREAK
best quality, 3persons in garden, an old man red suit
```
### Negative prompt
Negative prompts apply equally across all regions, but region-specific negative prompts can also be set. In that case, the number of `BREAK`s must match the number in the positive prompt; if the counts differ, the negative prompt is used without being divided into regions. A sketch follows.
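A hypothetical pairing, reusing the three-row prompt from the first sample (the negative sections are illustrative):
```
prompt = """
green hair twintail BREAK
red blouse BREAK
blue skirt
"""
# one negative section per region: same number of BREAKs as the positive prompt
negative_prompt = """
blonde BREAK
blue BREAK
red
"""
```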
### Parameters
To activate the Regional Prompting pipeline, settings must be entered in `rp_args`, a dictionary. The available keys are listed below.
### Input Parameters
Parameters are passed through `rp_args` (a dictionary):
```
rp_args = {
"mode":"rows",
"div": "1;1;1"
}
pipe(prompt =prompt, rp_args = rp_args)
```
### Required Parameters
- `mode`: Specifies the method for defining regions. Choose from `Cols`, `Rows`, `Prompt`, or `Prompt-Ex`. This parameter is case-insensitive.
- `div`: Used in `Cols` and `Rows` modes. Details on how to specify it are provided under the respective `Cols` and `Rows` sections.
- `th`: Used in `Prompt` mode. The method of specification is detailed under the `Prompt` section.
### Optional Parameters
- `save_mask`: In `Prompt` mode, choose whether to output the generated masks along with the image (see the sketch below). The default is `False`.
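A minimal sketch (mode and thresholds taken from the earlier sample):
```
rp_args = {
    "mode": "prompt",
    "th": "0.4,0.6,0.6",
    "save_mask": True,  # append the computed masks to the returned images
}
```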
The pipeline supports `compel` syntax; prompts written with `compel` weighting are detected and processed automatically, as sketched below.
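For instance, a sketch using `compel`'s weighting syntax (requires `compel` to be installed; the weights are illustrative):
```
prompt = """
green hair+ twintail BREAK
(red blouse)1.3 BREAK
blue skirt--
"""
```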
## Diffusion Posterior Sampling Pipeline
* Reference paper
```


@@ -0,0 +1,511 @@
import math
from typing import Dict, Optional

import torch
import torchvision
import torchvision.transforms.functional as FF
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from diffusers import StableDiffusionPipeline
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import USE_PEFT_BACKEND

try:
    from compel import Compel
except ImportError:
    Compel = None

KCOMM = "ADDCOMM"
KBRK = "BREAK"
class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
r"""
Args for Regional Prompting Pipeline:
rp_args:dict
Required
rp_args["mode"]: cols, rows, prompt, prompt-ex
for cols, rows mode
rp_args["div"]: ex) 1;1;1(Divide into 3 regions)
for prompt, prompt-ex mode
rp_args["th"]: ex) 0.5,0.5,0.6 (threshold for prompt mode)
Optional
rp_args["save_mask"]: True/False (save masks in prompt mode)
Pipeline for text-to-image generation using Stable Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder. Stable Diffusion uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
):
        super().__init__(
            vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
        )
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
@torch.no_grad()
def __call__(
self,
prompt: str,
height: int = 512,
width: int = 512,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: str = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
        rp_args: Optional[Dict[str, str]] = None,
    ):
        if rp_args is None:
            rp_args = {}
        active = KBRK in prompt[0] if isinstance(prompt, list) else KBRK in prompt
        if negative_prompt is None:
            negative_prompt = "" if isinstance(prompt, str) else [""] * len(prompt)
        device = self._execution_device
        regions = 0
        self.power = int(rp_args["power"]) if "power" in rp_args else 1
        prompts = prompt if isinstance(prompt, list) else [prompt]
        n_prompts = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt]
        self.batch = batch = num_images_per_prompt * len(prompts)
all_prompts_cn, all_prompts_p = promptsmaker(prompts,num_images_per_prompt)
all_n_prompts_cn, _ = promptsmaker(n_prompts,num_images_per_prompt)
cn = len(all_prompts_cn) == len(all_n_prompts_cn)
if Compel:
compel = Compel(tokenizer=self.tokenizer, text_encoder=self.text_encoder)
def getcompelembs(prps):
embl = []
for prp in prps:
embl.append(compel.build_conditioning_tensor(prp))
return torch.cat(embl)
conds = getcompelembs(all_prompts_cn)
unconds = getcompelembs(all_n_prompts_cn) if cn else getcompelembs(n_prompts)
embs = getcompelembs(prompts)
n_embs = getcompelembs(n_prompts)
prompt = negative_prompt = None
        else:
            # encode the region-expanded prompts so the number of embeddings matches the
            # hidden states expanded inside the attention hook below
            conds = self.encode_prompt(all_prompts_cn, device, 1, True)[0]
            unconds = (
                self.encode_prompt(all_n_prompts_cn, device, 1, True)[0]
                if cn
                else self.encode_prompt(n_prompts, device, 1, True)[0]
            )
            embs = n_embs = None
if not active:
pcallback = None
mode = None
else:
if any(x in rp_args["mode"].upper() for x in ["COL","ROW"]):
mode = "COL" if "COL" in rp_args["mode"].upper() else "ROW"
ocells,icells,regions = make_cells(rp_args["div"])
elif "PRO" in rp_args["mode"].upper():
regions = len(all_prompts_p[0])
mode = "PROMPT"
reset_attnmaps(self)
self.ex = "EX" in rp_args["mode"].upper()
self.target_tokens = target_tokens = tokendealer(self, all_prompts_p)
thresholds = [float(x) for x in rp_args["th"].split(",")]
orig_hw = (height,width)
revers = True
def pcallback(s_self, step: int, timestep: int, latents: torch.FloatTensor,selfs=None):
if "PRO" in mode: # in Prompt mode, make masks from sum of attension maps
self.step = step
if len(self.attnmaps_sizes) > 3:
self.history[step] = self.attnmaps.copy()
for hw in self.attnmaps_sizes:
allmasks = []
basemasks = [None] * batch
for tt, th in zip(target_tokens, thresholds):
for b in range(batch):
key = f"{tt}-{b}"
_, mask, _ = makepmask(self, self.attnmaps[key], hw[0], hw[1], th, step)
mask = mask.unsqueeze(0).unsqueeze(-1)
if self.ex:
allmasks[b::batch] = [x - mask for x in allmasks[b::batch]]
allmasks[b::batch] = [torch.where(x > 0, 1, 0) for x in allmasks[b::batch]]
allmasks.append(mask)
basemasks[b] = mask if basemasks[b] is None else basemasks[b] + mask
                    basemasks = [1 - mask for mask in basemasks]
basemasks = [torch.where(x > 0, 1, 0) for x in basemasks]
allmasks = basemasks + allmasks
self.attnmasks[hw] = torch.cat(allmasks)
self.maskready = True
return latents
def hook_forward(module):
#diffusers==0.23.2
def forward(
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
) -> torch.Tensor:
attn = module
xshape = hidden_states.shape
self.hw = (h,w) = split_dims(xshape[1], *orig_hw)
if revers:
nx,px = hidden_states.chunk(2)
else:
px,nx = hidden_states.chunk(2)
if cn:
hidden_states = torch.cat([px for i in range(regions)] + [nx for i in range(regions)],0)
encoder_hidden_states = torch.cat([conds]+[unconds])
else:
hidden_states = torch.cat([px for i in range(regions)] + [nx],0)
encoder_hidden_states = torch.cat([conds]+[unconds])
residual = hidden_states
args = () if USE_PEFT_BACKEND else (scale,)
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
args = () if USE_PEFT_BACKEND else (scale,)
query = attn.to_q(hidden_states, *args)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states, *args)
value = attn.to_v(encoder_hidden_states, *args)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = scaled_dot_product_attention(
self, query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, getattn = "PRO" in mode
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
#### Regional Prompting Col/Row mode
if any(x in mode for x in ["COL", "ROW"]):
reshaped = hidden_states.reshape(hidden_states.size()[0], h, w, hidden_states.size()[2])
center = reshaped.shape[0] // 2
px = reshaped[0:center] if cn else reshaped[0:-batch]
nx = reshaped[center:] if cn else reshaped[-batch:]
outs = [px,nx] if cn else [px]
for out in outs:
c = 0
for i,ocell in enumerate(ocells):
for icell in icells[i]:
if "ROW" in mode:
out[0:batch,int(h*ocell[0]):int(h*ocell[1]),int(w*icell[0]):int(w*icell[1]),:] = out[c*batch:(c+1)*batch,int(h*ocell[0]):int(h*ocell[1]),int(w*icell[0]):int(w*icell[1]),:]
else:
out[0:batch,int(h*icell[0]):int(h*icell[1]),int(w*ocell[0]):int(w*ocell[1]),:] = out[c*batch:(c+1)*batch,int(h*icell[0]):int(h*icell[1]),int(w*ocell[0]):int(w*ocell[1]),:]
c += 1
px, nx = (px[0:batch], nx[0:batch]) if cn else (px[0:batch], nx)
hidden_states = torch.cat([nx,px],0) if revers else torch.cat([px,nx],0)
hidden_states = hidden_states.reshape(xshape)
                #### Regional Prompting Prompt mode
                elif "PRO" in mode:
                    # the masks act on the flat (batch, seq, dim) hidden states;
                    # `reshaped` is only defined in the Col/Row branch above
                    center = hidden_states.shape[0] // 2
                    px = hidden_states[0:center] if cn else hidden_states[0:-batch]
                    nx = hidden_states[center:] if cn else hidden_states[-batch:]
if (h,w) in self.attnmasks and self.maskready:
def mask(input):
out = torch.multiply(input,self.attnmasks[(h,w)])
for b in range(batch):
for r in range(1, regions):
out[b] = out[b] + out[r * batch + b]
return out
px, nx = (mask(px), mask(nx)) if cn else (mask(px), nx)
px, nx = (px[0:batch], nx[0:batch]) if cn else (px[0:batch], nx)
hidden_states = torch.cat([nx,px],0) if revers else torch.cat([px,nx],0)
return hidden_states
return forward
        def hook_forwards(root_module: torch.nn.Module):
            for name, module in root_module.named_modules():
                if "attn2" in name and module.__class__.__name__ == "Attention":
                    module.forward = hook_forward(module)

        if active:
            # only patch the cross-attention modules when regional prompting is requested;
            # otherwise the hook would dereference mode=None during the U-Net forward pass
            hook_forwards(self.unet)
output = StableDiffusionPipeline(**self.components)(
prompt=prompt,
prompt_embeds=embs,
negative_prompt=negative_prompt,
negative_prompt_embeds=n_embs,
height=height,
width=width,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
num_images_per_prompt=num_images_per_prompt,
eta=eta,
generator=generator,
latents=latents,
output_type=output_type,
return_dict=return_dict,
callback_on_step_end = pcallback
)
if "save_mask" in rp_args:
save_mask = rp_args["save_mask"]
else:
save_mask = False
if mode == "PROMPT" and save_mask: saveattnmaps(self, output, height, width, thresholds, num_inference_steps // 2, regions)
return output
### Make a prompt list for each region
def promptsmaker(prompts,batch):
out_p = []
plen = len(prompts)
for prompt in prompts:
add = ""
if KCOMM in prompt:
add, prompt = prompt.split(KCOMM)
add = add + " "
prompts = prompt.split(KBRK)
out_p.append([add + p for p in prompts])
out = [None]*batch*len(out_p[0]) * len(out_p)
for p, prs in enumerate(out_p): # inputs prompts
for r, pr in enumerate(prs): # prompts for regions
start = (p + r * plen) * batch
out[start : start + batch]= [pr] * batch #P1R1B1,P1R1B2...,P1R2B1,P1R2B2...,P2R1B1...
return out, out_p
### make regions from ratios
### ";" makes outer cells, "," makes inner cells
def make_cells(ratios):
    if ";" not in ratios and "," in ratios:
        ratios = ratios.replace(",", ";")
ratios = ratios.split(";")
ratios = [inratios.split(",") for inratios in ratios]
icells = []
ocells = []
def startend(cells,array):
current_start = 0
array = [float(x) for x in array]
for value in array:
end = current_start + (value / sum(array))
cells.append([current_start, end])
current_start = end
startend(ocells,[r[0] for r in ratios])
for inratios in ratios:
        if len(inratios) < 2:
icells.append([[0,1]])
else:
add = []
startend(add,inratios[1:])
icells.append(add)
return ocells, icells, sum(len(cell) for cell in icells)
def make_emblist(self, prompts):
    with torch.no_grad():
        tokens = self.tokenizer(
            prompts, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
        ).input_ids.to(self.device)
        embs = self.text_encoder(tokens, output_hidden_states=True).last_hidden_state.to(self.device, dtype=self.dtype)
    return embs
def split_dims(xs, height, width):
def repeat_div(x,y):
while y > 0:
x = math.ceil(x / 2)
y = y - 1
return x
scale = math.ceil(math.log2(math.sqrt(height * width / xs)))
dsh = repeat_div(height,scale)
dsw = repeat_div(width,scale)
return dsh,dsw
##### for prompt mode
def get_attn_maps(self,attn):
height,width = self.hw
target_tokens = self.target_tokens
if (height,width) not in self.attnmaps_sizes:
self.attnmaps_sizes.append((height,width))
for b in range(self.batch):
for t in target_tokens:
power = self.power
add = attn[b,:,:,t[0]:t[0]+len(t)]**(power)*(self.attnmaps_sizes.index((height,width)) + 1)
add = torch.sum(add,dim = 2)
key = f"{t}-{b}"
if key not in self.attnmaps:
self.attnmaps[key] = add
else:
if self.attnmaps[key].shape[1] != add.shape[1]:
add = add.view(8,height,width)
add = FF.resize(add,self.attnmaps_sizes[0],antialias=None)
add = add.reshape_as(self.attnmaps[key])
self.attnmaps[key] = self.attnmaps[key] + add
def reset_attnmaps(self):  # init parameters in every batch
    self.step = 0
    self.attnmaps = {}  # made from attention maps
    self.attnmaps_sizes = []  # height, width set of U-Net blocks
    self.attnmasks = {}  # made from attnmaps for regions
    self.maskready = False
    self.history = {}
def saveattnmaps(self,output,h,w,th,step,regions):
masks = []
for i, mask in enumerate(self.history[step].values()):
img, _ , mask = makepmask(self, mask, h, w, th[i % len(th)], step)
if self.ex:
masks = [x - mask for x in masks]
masks.append(mask)
if len(masks) == regions - 1:
output.images.extend([FF.to_pil_image(mask) for mask in masks])
masks = []
else:
output.images.append(img)
def makepmask(self, mask, h, w, th, step):  # make masks from attention cache; returns [for preview, for attention, for latent]
    th = th - step * 0.005
    if th < 0.05:
        th = 0.05
mask = torch.mean(mask,dim=0)
mask = mask / mask.max().item()
mask = torch.where(mask > th ,1,0)
mask = mask.float()
mask = mask.view(1,*self.attnmaps_sizes[0])
img = FF.to_pil_image(mask)
img = img.resize((w,h))
mask = FF.resize(mask,(h,w),interpolation=FF.InterpolationMode.NEAREST,antialias=None)
lmask = mask
mask = mask.reshape(h*w)
mask = torch.where(mask > 0.1 ,1,0)
return img, mask, lmask
def tokendealer(self, all_prompts):
    for prompts in all_prompts:
        # the last comma-separated word of each region prompt is the mask target
        targets = [p.split(",")[-1] for p in prompts[1:]]
tt = []
for target in targets:
ptokens = (self.tokenizer(prompts, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors='pt').input_ids)[0]
ttokens = (self.tokenizer(target, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors='pt').input_ids)[0]
tlist = []
for t in range(ttokens.shape[0] -2):
for p in range(ptokens.shape[0]):
if ttokens[t + 1] == ptokens[p]:
tlist.append(p)
if tlist != [] : tt.append(tlist)
return tt
def scaled_dot_product_attention(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, getattn=False) -> torch.Tensor:
    # Efficient implementation equivalent to torch.nn.functional.scaled_dot_product_attention,
    # with a hook to capture the attention weights for Prompt mode
L, S = query.size(-2), key.size(-2)
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
attn_bias = torch.zeros(L, S, dtype=query.dtype,device=self.device)
if is_causal:
assert attn_mask is None
temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias.to(query.dtype)
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias += attn_mask
attn_weight = query @ key.transpose(-2, -1) * scale_factor
attn_weight += attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1)
if getattn: get_attn_maps(self,attn_weight)
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
return attn_weight @ value