1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Fix] Fix Regional Prompting Pipeline (#6188)

* Update regional_prompting_stable_diffusion.py

* reformat

* reformat

* reformat

* reformat

* reformat

* reformat

* reformat

* regormat

* reformat

* reformat

* reformat

* reformat

* Update regional_prompting_stable_diffusion.py

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
hako-mikan
2023-12-20 14:07:19 +09:00
committed by GitHub
parent 5433962992
commit ff43dba7ea

View File

@@ -73,7 +73,14 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
requires_safety_checker: bool = True,
):
super().__init__(
vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
vae,
text_encoder,
tokenizer,
unet,
scheduler,
safety_checker,
feature_extractor,
requires_safety_checker,
)
self.register_modules(
vae=vae,
@@ -102,22 +109,22 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
return_dict: bool = True,
rp_args: Dict[str, str] = None,
):
active = KBRK in prompt[0] if type(prompt) == list else KBRK in prompt # noqa: E721
active = KBRK in prompt[0] if isinstance(prompt, list) else KBRK in prompt
if negative_prompt is None:
negative_prompt = "" if type(prompt) == str else [""] * len(prompt) # noqa: E721
negative_prompt = "" if isinstance(prompt, str) else [""] * len(prompt)
device = self._execution_device
regions = 0
self.power = int(rp_args["power"]) if "power" in rp_args else 1
prompts = prompt if type(prompt) == list else [prompt] # noqa: E721
n_prompts = negative_prompt if type(negative_prompt) == list else [negative_prompt] # noqa: E721
prompts = prompt if isinstance(prompt, list) else [prompt]
n_prompts = negative_prompt if isinstance(prompt, str) else [negative_prompt]
self.batch = batch = num_images_per_prompt * len(prompts)
all_prompts_cn, all_prompts_p = promptsmaker(prompts, num_images_per_prompt)
all_n_prompts_cn, _ = promptsmaker(n_prompts, num_images_per_prompt)
cn = len(all_prompts_cn) == len(all_n_prompts_cn)
equal = len(all_prompts_cn) == len(all_n_prompts_cn)
if Compel:
compel = Compel(tokenizer=self.tokenizer, text_encoder=self.text_encoder)
@@ -129,7 +136,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
return torch.cat(embl)
conds = getcompelembs(all_prompts_cn)
unconds = getcompelembs(all_n_prompts_cn) if cn else getcompelembs(n_prompts)
unconds = getcompelembs(all_n_prompts_cn)
embs = getcompelembs(prompts)
n_embs = getcompelembs(n_prompts)
prompt = negative_prompt = None
@@ -137,7 +144,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
conds = self.encode_prompt(prompts, device, 1, True)[0]
unconds = (
self.encode_prompt(n_prompts, device, 1, True)[0]
if cn
if equal
else self.encode_prompt(all_n_prompts_cn, device, 1, True)[0]
)
embs = n_embs = None
@@ -206,8 +213,11 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
else:
px, nx = hidden_states.chunk(2)
if cn:
hidden_states = torch.cat([px for i in range(regions)] + [nx for i in range(regions)], 0)
if equal:
hidden_states = torch.cat(
[px for i in range(regions)] + [nx for i in range(regions)],
0,
)
encoder_hidden_states = torch.cat([conds] + [unconds])
else:
hidden_states = torch.cat([px for i in range(regions)] + [nx], 0)
@@ -289,9 +299,9 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
if any(x in mode for x in ["COL", "ROW"]):
reshaped = hidden_states.reshape(hidden_states.size()[0], h, w, hidden_states.size()[2])
center = reshaped.shape[0] // 2
px = reshaped[0:center] if cn else reshaped[0:-batch]
nx = reshaped[center:] if cn else reshaped[-batch:]
outs = [px, nx] if cn else [px]
px = reshaped[0:center] if equal else reshaped[0:-batch]
nx = reshaped[center:] if equal else reshaped[-batch:]
outs = [px, nx] if equal else [px]
for out in outs:
c = 0
for i, ocell in enumerate(ocells):
@@ -321,15 +331,16 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
:,
]
c += 1
px, nx = (px[0:batch], nx[0:batch]) if cn else (px[0:batch], nx)
px, nx = (px[0:batch], nx[0:batch]) if equal else (px[0:batch], nx)
hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0)
hidden_states = hidden_states.reshape(xshape)
#### Regional Prompting Prompt mode
elif "PRO" in mode:
center = reshaped.shape[0] // 2
px = reshaped[0:center] if cn else reshaped[0:-batch]
nx = reshaped[center:] if cn else reshaped[-batch:]
px, nx = (
torch.chunk(hidden_states) if equal else hidden_states[0:-batch],
hidden_states[-batch:],
)
if (h, w) in self.attnmasks and self.maskready:
@@ -340,8 +351,8 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
out[b] = out[b] + out[r * batch + b]
return out
px, nx = (mask(px), mask(nx)) if cn else (mask(px), nx)
px, nx = (px[0:batch], nx[0:batch]) if cn else (px[0:batch], nx)
px, nx = (mask(px), mask(nx)) if equal else (mask(px), nx)
px, nx = (px[0:batch], nx[0:batch]) if equal else (px[0:batch], nx)
hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0)
return hidden_states
@@ -378,7 +389,15 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
save_mask = False
if mode == "PROMPT" and save_mask:
saveattnmaps(self, output, height, width, thresholds, num_inference_steps // 2, regions)
saveattnmaps(
self,
output,
height,
width,
thresholds,
num_inference_steps // 2,
regions,
)
return output
@@ -437,7 +456,11 @@ def make_cells(ratios):
def make_emblist(self, prompts):
with torch.no_grad():
tokens = self.tokenizer(
prompts, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
prompts,
max_length=self.tokenizer.model_max_length,
padding=True,
truncation=True,
return_tensors="pt",
).input_ids.to(self.device)
embs = self.text_encoder(tokens, output_hidden_states=True).last_hidden_state.to(self.device, dtype=self.dtype)
return embs
@@ -563,7 +586,15 @@ def tokendealer(self, all_prompts):
def scaled_dot_product_attention(
self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, getattn=False
self,
query,
key,
value,
attn_mask=None,
dropout_p=0.0,
is_causal=False,
scale=None,
getattn=False,
) -> torch.Tensor:
# Efficient implementation equivalent to the following:
L, S = query.size(-2), key.size(-2)