mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Community Pipelines] Fix pad_tokens_and_weights in lpw_stable_diffusion (#925)
[Community Pipelines] fix pad_tokens_and_weights in lpw_stable_diffusion
This commit is contained in:
@@ -132,6 +132,7 @@ def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_len
|
||||
"""
|
||||
tokens = []
|
||||
weights = []
|
||||
truncated = False
|
||||
for text in prompt:
|
||||
texts_and_weights = parse_prompt_attention(text)
|
||||
text_token = []
|
||||
@@ -140,21 +141,21 @@ def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_len
|
||||
# tokenize and discard the starting and the ending token
|
||||
token = pipe.tokenizer(word).input_ids[1:-1]
|
||||
text_token += token
|
||||
|
||||
# copy the weight by length of token
|
||||
text_weight += [weight] * len(token)
|
||||
|
||||
# stop if the text is too long (longer than truncation limit)
|
||||
if len(text_token) > max_length:
|
||||
truncated = True
|
||||
break
|
||||
|
||||
# truncate
|
||||
if len(text_token) > max_length:
|
||||
truncated = True
|
||||
text_token = text_token[:max_length]
|
||||
text_weight = text_weight[:max_length]
|
||||
|
||||
tokens.append(text_token)
|
||||
weights.append(text_weight)
|
||||
if truncated:
|
||||
logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
|
||||
return tokens, weights
|
||||
|
||||
|
||||
@@ -173,9 +174,9 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_midd
|
||||
if len(weights[i]) == 0:
|
||||
w = [1.0] * weights_length
|
||||
else:
|
||||
for j in range((len(weights[i]) - 1) // chunk_length + 1):
|
||||
for j in range(max_embeddings_multiples):
|
||||
w.append(1.0) # weight for starting token in this chunk
|
||||
w += weights[i][j * chunk_length : min(len(weights[i]), (j + 1) * chunk_length)]
|
||||
w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
|
||||
w.append(1.0) # weight for ending token in this chunk
|
||||
w += [1.0] * (weights_length - len(w))
|
||||
weights[i] = w[:]
|
||||
@@ -184,7 +185,10 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_midd
|
||||
|
||||
|
||||
def get_unweighted_text_embeddings(
|
||||
pipe: DiffusionPipeline, text_input: torch.Tensor, chunk_length: int, no_boseos_middle: Optional[bool] = True
|
||||
pipe: DiffusionPipeline,
|
||||
text_input: torch.Tensor,
|
||||
chunk_length: int,
|
||||
no_boseos_middle: Optional[bool] = True,
|
||||
):
|
||||
"""
|
||||
When the length of tokens is a multiple of the capacity of the text encoder,
|
||||
@@ -285,7 +289,8 @@ def get_weighted_text_embeddings(
|
||||
max_length = max(max_length, max([len(token) for token in uncond_tokens]))
|
||||
|
||||
max_embeddings_multiples = min(
|
||||
max_embeddings_multiples, (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1
|
||||
max_embeddings_multiples,
|
||||
(max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
|
||||
)
|
||||
max_embeddings_multiples = max(1, max_embeddings_multiples)
|
||||
max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
|
||||
@@ -317,12 +322,18 @@ def get_weighted_text_embeddings(
|
||||
|
||||
# get the embeddings
|
||||
text_embeddings = get_unweighted_text_embeddings(
|
||||
pipe, prompt_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
|
||||
pipe,
|
||||
prompt_tokens,
|
||||
pipe.tokenizer.model_max_length,
|
||||
no_boseos_middle=no_boseos_middle,
|
||||
)
|
||||
prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
|
||||
if uncond_prompt is not None:
|
||||
uncond_embeddings = get_unweighted_text_embeddings(
|
||||
pipe, uncond_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
|
||||
pipe,
|
||||
uncond_tokens,
|
||||
pipe.tokenizer.model_max_length,
|
||||
no_boseos_middle=no_boseos_middle,
|
||||
)
|
||||
uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
|
||||
|
||||
@@ -632,16 +643,29 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
||||
# Unlike in other pipelines, latents need to be generated in the target device
|
||||
# for 1-to-1 results reproducibility with the CompVis implementation.
|
||||
# However this currently doesn't work in `mps`.
|
||||
latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
|
||||
latents_shape = (
|
||||
batch_size * num_images_per_prompt,
|
||||
self.unet.in_channels,
|
||||
height // 8,
|
||||
width // 8,
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
if self.device.type == "mps":
|
||||
# randn does not exist on mps
|
||||
latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
|
||||
self.device
|
||||
)
|
||||
latents = torch.randn(
|
||||
latents_shape,
|
||||
generator=generator,
|
||||
device="cpu",
|
||||
dtype=latents_dtype,
|
||||
).to(self.device)
|
||||
else:
|
||||
latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
|
||||
latents = torch.randn(
|
||||
latents_shape,
|
||||
generator=generator,
|
||||
device=self.device,
|
||||
dtype=latents_dtype,
|
||||
)
|
||||
else:
|
||||
if latents.shape != latents_shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
|
||||
@@ -684,11 +708,19 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
||||
# add noise to latents using the timesteps
|
||||
if self.device.type == "mps":
|
||||
# randn does not exist on mps
|
||||
noise = torch.randn(init_latents.shape, generator=generator, device="cpu", dtype=latents_dtype).to(
|
||||
self.device
|
||||
)
|
||||
noise = torch.randn(
|
||||
init_latents.shape,
|
||||
generator=generator,
|
||||
device="cpu",
|
||||
dtype=latents_dtype,
|
||||
).to(self.device)
|
||||
else:
|
||||
noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
|
||||
noise = torch.randn(
|
||||
init_latents.shape,
|
||||
generator=generator,
|
||||
device=self.device,
|
||||
dtype=latents_dtype,
|
||||
)
|
||||
latents = self.scheduler.add_noise(init_latents, noise, timesteps)
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep + offset, 0)
|
||||
@@ -741,7 +773,8 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
||||
self.device
|
||||
)
|
||||
image, has_nsfw_concept = self.safety_checker(
|
||||
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
|
||||
images=image,
|
||||
clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype),
|
||||
)
|
||||
else:
|
||||
has_nsfw_concept = None
|
||||
|
||||
@@ -130,6 +130,7 @@ def get_prompts_with_weights(pipe, prompt: List[str], max_length: int):
|
||||
"""
|
||||
tokens = []
|
||||
weights = []
|
||||
truncated = False
|
||||
for text in prompt:
|
||||
texts_and_weights = parse_prompt_attention(text)
|
||||
text_token = []
|
||||
@@ -138,21 +139,21 @@ def get_prompts_with_weights(pipe, prompt: List[str], max_length: int):
|
||||
# tokenize and discard the starting and the ending token
|
||||
token = pipe.tokenizer(word, return_tensors="np").input_ids[0, 1:-1]
|
||||
text_token += list(token)
|
||||
|
||||
# copy the weight by length of token
|
||||
text_weight += [weight] * len(token)
|
||||
|
||||
# stop if the text is too long (longer than truncation limit)
|
||||
if len(text_token) > max_length:
|
||||
truncated = True
|
||||
break
|
||||
|
||||
# truncate
|
||||
if len(text_token) > max_length:
|
||||
truncated = True
|
||||
text_token = text_token[:max_length]
|
||||
text_weight = text_weight[:max_length]
|
||||
|
||||
tokens.append(text_token)
|
||||
weights.append(text_weight)
|
||||
if truncated:
|
||||
logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
|
||||
return tokens, weights
|
||||
|
||||
|
||||
@@ -171,9 +172,9 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_midd
|
||||
if len(weights[i]) == 0:
|
||||
w = [1.0] * weights_length
|
||||
else:
|
||||
for j in range((len(weights[i]) - 1) // chunk_length + 1):
|
||||
for j in range(max_embeddings_multiples):
|
||||
w.append(1.0) # weight for starting token in this chunk
|
||||
w += weights[i][j * chunk_length : min(len(weights[i]), (j + 1) * chunk_length)]
|
||||
w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
|
||||
w.append(1.0) # weight for ending token in this chunk
|
||||
w += [1.0] * (weights_length - len(w))
|
||||
weights[i] = w[:]
|
||||
@@ -182,7 +183,10 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_midd
|
||||
|
||||
|
||||
def get_unweighted_text_embeddings(
|
||||
pipe, text_input: np.array, chunk_length: int, no_boseos_middle: Optional[bool] = True
|
||||
pipe,
|
||||
text_input: np.array,
|
||||
chunk_length: int,
|
||||
no_boseos_middle: Optional[bool] = True,
|
||||
):
|
||||
"""
|
||||
When the length of tokens is a multiple of the capacity of the text encoder,
|
||||
@@ -276,7 +280,10 @@ def get_weighted_text_embeddings(
|
||||
uncond_tokens = [
|
||||
token[1:-1]
|
||||
for token in pipe.tokenizer(
|
||||
uncond_prompt, max_length=max_length, truncation=True, return_tensors="np"
|
||||
uncond_prompt,
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="np",
|
||||
).input_ids
|
||||
]
|
||||
uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
|
||||
@@ -287,7 +294,8 @@ def get_weighted_text_embeddings(
|
||||
max_length = max(max_length, max([len(token) for token in uncond_tokens]))
|
||||
|
||||
max_embeddings_multiples = min(
|
||||
max_embeddings_multiples, (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1
|
||||
max_embeddings_multiples,
|
||||
(max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
|
||||
)
|
||||
max_embeddings_multiples = max(1, max_embeddings_multiples)
|
||||
max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
|
||||
@@ -319,12 +327,18 @@ def get_weighted_text_embeddings(
|
||||
|
||||
# get the embeddings
|
||||
text_embeddings = get_unweighted_text_embeddings(
|
||||
pipe, prompt_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
|
||||
pipe,
|
||||
prompt_tokens,
|
||||
pipe.tokenizer.model_max_length,
|
||||
no_boseos_middle=no_boseos_middle,
|
||||
)
|
||||
prompt_weights = np.array(prompt_weights, dtype=text_embeddings.dtype)
|
||||
if uncond_prompt is not None:
|
||||
uncond_embeddings = get_unweighted_text_embeddings(
|
||||
pipe, uncond_tokens, pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle
|
||||
pipe,
|
||||
uncond_tokens,
|
||||
pipe.tokenizer.model_max_length,
|
||||
no_boseos_middle=no_boseos_middle,
|
||||
)
|
||||
uncond_weights = np.array(uncond_weights, dtype=uncond_embeddings.dtype)
|
||||
|
||||
@@ -559,7 +573,12 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
||||
noise = None
|
||||
|
||||
if init_image is None:
|
||||
latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
|
||||
latents_shape = (
|
||||
batch_size * num_images_per_prompt,
|
||||
4,
|
||||
height // 8,
|
||||
width // 8,
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = generator.randn(*latents_shape).astype(latents_dtype)
|
||||
@@ -625,7 +644,9 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=text_embeddings
|
||||
sample=latent_model_input,
|
||||
timestep=np.array([t]),
|
||||
encoder_hidden_states=text_embeddings,
|
||||
)
|
||||
noise_pred = noise_pred[0]
|
||||
|
||||
@@ -640,7 +661,9 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
||||
if mask is not None:
|
||||
# masking
|
||||
init_latents_proper = self.scheduler.add_noise(
|
||||
torch.from_numpy(init_latents_orig), torch.from_numpy(noise), torch.tensor([t])
|
||||
torch.from_numpy(init_latents_orig),
|
||||
torch.from_numpy(noise),
|
||||
torch.tensor([t]),
|
||||
).numpy()
|
||||
latents = (init_latents_proper * mask) + (latents * (1 - mask))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user