1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Community Pipelines] Fix pad_tokens_and_weights in lpw_stable_diffusion (#925)

[Community Pipelines] fix pad_tokens_and_weights in lpw_stable_diffusion
This commit is contained in:
SkyTNT
2022-10-21 01:26:04 +08:00
committed by GitHub
parent 6f6eef747c
commit ba74a8be7a
2 changed files with 90 additions and 34 deletions

View File

@@ -132,6 +132,7 @@ def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_len
"""
tokens = []
weights = []
truncated = False
for text in prompt:
texts_and_weights = parse_prompt_attention(text)
text_token = []
@@ -140,21 +141,21 @@ def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_len
# tokenize and discard the starting and the ending token
token = pipe.tokenizer(word).input_ids[1:-1]
text_token += token
# copy the weight by length of token
text_weight += [weight] * len(token)
# stop if the text is too long (longer than truncation limit)
if len(text_token) > max_length:
truncated = True
break
# truncate
if len(text_token) > max_length:
truncated = True
text_token = text_token[:max_length]
text_weight = text_weight[:max_length]
tokens.append(text_token)
weights.append(text_weight)
if truncated:
logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
return tokens, weights
@@ -173,9 +174,9 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_midd
if len(weights[i]) == 0:
w = [1.0] * weights_length
else:
for j in range((len(weights[i]) - 1) // chunk_length + 1):
for j in range(max_embeddings_multiples):
w.append(1.0) # weight for starting token in this chunk
w += weights[i][j * chunk_length : min(len(weights[i]), (j + 1) * chunk_length)]
w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
w.append(1.0) # weight for ending token in this chunk
w += [1.0] * (weights_length - len(w))
weights[i] = w[:]
@@ -184,7 +185,10 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_midd
def get_unweighted_text_embeddings(
pipe: DiffusionPipeline, text_input: torch.Tensor, chunk_length: int, no_boseos_middle: Optional[bool] = True
pipe: DiffusionPipeline,
text_input: torch.Tensor,
chunk_length: int,
no_boseos_middle: Optional[bool] = True,
):
"""
When the length of tokens is a multiple of the capacity of the text encoder,
@@ -285,7 +289,8 @@ def get_weighted_text_embeddings(
max_length = max(max_length, max([len(token) for token in uncond_tokens]))
max_embeddings_multiples = min(
max_embeddings_multiples, (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1
max_embeddings_multiples,
(max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
)
max_embeddings_multiples = max(1, max_embeddings_multiples)
max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
@@ -317,12 +322,18 @@ def get_weighted_text_embeddings(
# get the embeddings
text_embeddings = get_unweighted_text_embeddings(
pipe, prompt_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
pipe,
prompt_tokens,
pipe.tokenizer.model_max_length,
no_boseos_middle=no_boseos_middle,
)
prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
if uncond_prompt is not None:
uncond_embeddings = get_unweighted_text_embeddings(
pipe, uncond_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
pipe,
uncond_tokens,
pipe.tokenizer.model_max_length,
no_boseos_middle=no_boseos_middle,
)
uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
@@ -632,16 +643,29 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
# Unlike in other pipelines, latents need to be generated in the target device
# for 1-to-1 results reproducibility with the CompVis implementation.
# However this currently doesn't work in `mps`.
latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
latents_shape = (
batch_size * num_images_per_prompt,
self.unet.in_channels,
height // 8,
width // 8,
)
if latents is None:
if self.device.type == "mps":
# randn does not exist on mps
latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
self.device
)
latents = torch.randn(
latents_shape,
generator=generator,
device="cpu",
dtype=latents_dtype,
).to(self.device)
else:
latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
latents = torch.randn(
latents_shape,
generator=generator,
device=self.device,
dtype=latents_dtype,
)
else:
if latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
@@ -684,11 +708,19 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
# add noise to latents using the timesteps
if self.device.type == "mps":
# randn does not exist on mps
noise = torch.randn(init_latents.shape, generator=generator, device="cpu", dtype=latents_dtype).to(
self.device
)
noise = torch.randn(
init_latents.shape,
generator=generator,
device="cpu",
dtype=latents_dtype,
).to(self.device)
else:
noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
noise = torch.randn(
init_latents.shape,
generator=generator,
device=self.device,
dtype=latents_dtype,
)
latents = self.scheduler.add_noise(init_latents, noise, timesteps)
t_start = max(num_inference_steps - init_timestep + offset, 0)
@@ -741,7 +773,8 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
self.device
)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
images=image,
clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype),
)
else:
has_nsfw_concept = None

View File

@@ -130,6 +130,7 @@ def get_prompts_with_weights(pipe, prompt: List[str], max_length: int):
"""
tokens = []
weights = []
truncated = False
for text in prompt:
texts_and_weights = parse_prompt_attention(text)
text_token = []
@@ -138,21 +139,21 @@ def get_prompts_with_weights(pipe, prompt: List[str], max_length: int):
# tokenize and discard the starting and the ending token
token = pipe.tokenizer(word, return_tensors="np").input_ids[0, 1:-1]
text_token += list(token)
# copy the weight by length of token
text_weight += [weight] * len(token)
# stop if the text is too long (longer than truncation limit)
if len(text_token) > max_length:
truncated = True
break
# truncate
if len(text_token) > max_length:
truncated = True
text_token = text_token[:max_length]
text_weight = text_weight[:max_length]
tokens.append(text_token)
weights.append(text_weight)
if truncated:
logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
return tokens, weights
@@ -171,9 +172,9 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_midd
if len(weights[i]) == 0:
w = [1.0] * weights_length
else:
for j in range((len(weights[i]) - 1) // chunk_length + 1):
for j in range(max_embeddings_multiples):
w.append(1.0) # weight for starting token in this chunk
w += weights[i][j * chunk_length : min(len(weights[i]), (j + 1) * chunk_length)]
w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
w.append(1.0) # weight for ending token in this chunk
w += [1.0] * (weights_length - len(w))
weights[i] = w[:]
@@ -182,7 +183,10 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_midd
def get_unweighted_text_embeddings(
pipe, text_input: np.array, chunk_length: int, no_boseos_middle: Optional[bool] = True
pipe,
text_input: np.array,
chunk_length: int,
no_boseos_middle: Optional[bool] = True,
):
"""
When the length of tokens is a multiple of the capacity of the text encoder,
@@ -276,7 +280,10 @@ def get_weighted_text_embeddings(
uncond_tokens = [
token[1:-1]
for token in pipe.tokenizer(
uncond_prompt, max_length=max_length, truncation=True, return_tensors="np"
uncond_prompt,
max_length=max_length,
truncation=True,
return_tensors="np",
).input_ids
]
uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
@@ -287,7 +294,8 @@ def get_weighted_text_embeddings(
max_length = max(max_length, max([len(token) for token in uncond_tokens]))
max_embeddings_multiples = min(
max_embeddings_multiples, (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1
max_embeddings_multiples,
(max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
)
max_embeddings_multiples = max(1, max_embeddings_multiples)
max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
@@ -319,12 +327,18 @@ def get_weighted_text_embeddings(
# get the embeddings
text_embeddings = get_unweighted_text_embeddings(
pipe, prompt_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
pipe,
prompt_tokens,
pipe.tokenizer.model_max_length,
no_boseos_middle=no_boseos_middle,
)
prompt_weights = np.array(prompt_weights, dtype=text_embeddings.dtype)
if uncond_prompt is not None:
uncond_embeddings = get_unweighted_text_embeddings(
pipe, uncond_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
pipe,
uncond_tokens,
pipe.tokenizer.model_max_length,
no_boseos_middle=no_boseos_middle,
)
uncond_weights = np.array(uncond_weights, dtype=uncond_embeddings.dtype)
@@ -559,7 +573,12 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
noise = None
if init_image is None:
latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
latents_shape = (
batch_size * num_images_per_prompt,
4,
height // 8,
width // 8,
)
if latents is None:
latents = generator.randn(*latents_shape).astype(latents_dtype)
@@ -625,7 +644,9 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
# predict the noise residual
noise_pred = self.unet(
sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=text_embeddings
sample=latent_model_input,
timestep=np.array([t]),
encoder_hidden_states=text_embeddings,
)
noise_pred = noise_pred[0]
@@ -640,7 +661,9 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
if mask is not None:
# masking
init_latents_proper = self.scheduler.add_noise(
torch.from_numpy(init_latents_orig), torch.from_numpy(noise), torch.tensor([t])
torch.from_numpy(init_latents_orig),
torch.from_numpy(noise),
torch.tensor([t]),
).numpy()
latents = (init_latents_proper * mask) + (latents * (1 - mask))