1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Stable Diffusion] Fix padding / truncation (#1226)

* [Stable Diffusion] Fix padding / truncation

* finish
This commit is contained in:
Patrick von Platen
2022-11-13 20:19:55 +01:00
committed by GitHub
parent 8171566163
commit 4c660d16d0
9 changed files with 88 additions and 25 deletions

View File

@@ -248,17 +248,18 @@ class CycleDiffusionPipeline(DiffusionPipeline):
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
if not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_embeddings = self.text_encoder(text_input_ids.to(device))[0]
# duplicate text embeddings for each generation per prompt, using mps friendly method

View File

@@ -114,17 +114,19 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="np",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
if not np.array_equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0)

View File

@@ -161,17 +161,19 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="np",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
if not np.array_equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0)

View File

@@ -175,17 +175,19 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="np",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
if not np.array_equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0)

View File

@@ -236,17 +236,18 @@ class StableDiffusionPipeline(DiffusionPipeline):
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
if not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_embeddings = self.text_encoder(text_input_ids.to(device))[0]
# duplicate text embeddings for each generation per prompt, using mps friendly method

View File

@@ -244,17 +244,18 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
if not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_embeddings = self.text_encoder(text_input_ids.to(device))[0]
# duplicate text embeddings for each generation per prompt, using mps friendly method

View File

@@ -244,17 +244,18 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
if not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_embeddings = self.text_encoder(text_input_ids.to(device))[0]
# duplicate text embeddings for each generation per prompt, using mps friendly method

View File

@@ -213,17 +213,18 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
if not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_embeddings = self.text_encoder(text_input_ids.to(device))[0]
# duplicate text embeddings for each generation per prompt, using mps friendly method

View File

@@ -33,9 +33,10 @@ from diffusers import (
UNet2DConditionModel,
UNet2DModel,
VQModel,
logging,
)
from diffusers.utils import floats_tensor, load_numpy, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu
from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from ...test_pipelines_common import PipelineTesterMixin
@@ -619,6 +620,57 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
assert image.shape == (1, 128, 128, 3)
def test_stable_diffusion_long_prompt(self):
unet = self.dummy_cond_unet
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=self.dummy_extractor,
)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
do_classifier_free_guidance = True
negative_prompt = None
num_images_per_prompt = 1
logger = logging.get_logger("diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion")
prompt = 25 * "@"
with CaptureLogger(logger) as cap_logger_3:
text_embeddings_3 = sd_pipe._encode_prompt(
prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
)
prompt = 100 * "@"
with CaptureLogger(logger) as cap_logger:
text_embeddings = sd_pipe._encode_prompt(
prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
)
negative_prompt = "Hello"
with CaptureLogger(logger) as cap_logger_2:
text_embeddings_2 = sd_pipe._encode_prompt(
prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
)
assert text_embeddings_3.shape == text_embeddings_2.shape == text_embeddings.shape
assert text_embeddings.shape[1] == 77
assert cap_logger.out == cap_logger_2.out
# 100 - 77 + 1 (BOS token) + 1 (EOS token) = 25
assert cap_logger.out.count("@") == 25
assert cap_logger_3.out == ""
@slow
@require_torch_gpu