diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
index 50f519c77f..528dd33794 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
@@ -248,17 +248,18 @@ class CycleDiffusionPipeline(DiffusionPipeline):
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            truncation=True,
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
 
-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+        if not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(text_input_ids.to(device))[0]
 
         # duplicate text embeddings for each generation per prompt, using mps friendly method
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py
index d1e2704fce..eceefea874 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py
@@ -114,17 +114,19 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            truncation=True,
             return_tensors="np",
         )
         text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
 
-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+        if not np.array_equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+
         text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
 
         text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
index a09dfe751f..483b5fd2d3 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
@@ -161,17 +161,19 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            truncation=True,
             return_tensors="np",
         )
         text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
 
-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+        if not np.array_equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+
         text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
 
         text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
index 6c226dd432..8e5c201319 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
@@ -175,17 +175,19 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            truncation=True,
             return_tensors="np",
         )
         text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
 
-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+        if not np.array_equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+
         text_embeddings = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
 
         text_embeddings = np.repeat(text_embeddings, num_images_per_prompt, axis=0)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index ed5246fa91..450fbbfb17 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -236,17 +236,18 @@ class StableDiffusionPipeline(DiffusionPipeline):
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            truncation=True,
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
 
-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+        if not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(text_input_ids.to(device))[0]
 
         # duplicate text embeddings for each generation per prompt, using mps friendly method
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
index b4ecf600b1..98c813eed1 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -244,17 +244,18 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            truncation=True,
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
 
-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+        if not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(text_input_ids.to(device))[0]
 
         # duplicate text embeddings for each generation per prompt, using mps friendly method
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index f42325261a..3f08f6edae 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -244,17 +244,18 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            truncation=True,
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
 
-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+        if not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(text_input_ids.to(device))[0]
 
         # duplicate text embeddings for each generation per prompt, using mps friendly method
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
index 89d40f7a79..612aa3c126 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py
@@ -213,17 +213,18 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
             prompt,
             padding="max_length",
             max_length=self.tokenizer.model_max_length,
+            truncation=True,
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
 
-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+        if not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
             logger.warning(
                 "The following part of your input was truncated because CLIP can only handle sequences up to"
                 f" {self.tokenizer.model_max_length} tokens: {removed_text}"
             )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
         text_embeddings = self.text_encoder(text_input_ids.to(device))[0]
 
         # duplicate text embeddings for each generation per prompt, using mps friendly method
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
index 6e1071124c..87d238c869 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
@@ -33,9 +33,10 @@ from diffusers import (
     UNet2DConditionModel,
     UNet2DModel,
     VQModel,
+    logging,
 )
 from diffusers.utils import floats_tensor, load_numpy, slow, torch_device
-from diffusers.utils.testing_utils import require_torch_gpu
+from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 
 from ...test_pipelines_common import PipelineTesterMixin
@@ -619,6 +620,57 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
 
         assert image.shape == (1, 128, 128, 3)
 
+    def test_stable_diffusion_long_prompt(self):
+        unet = self.dummy_cond_unet
+        scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        # make sure here that pndm scheduler skips prk
+        sd_pipe = StableDiffusionPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=None,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        do_classifier_free_guidance = True
+        negative_prompt = None
+        num_images_per_prompt = 1
+        logger = logging.get_logger("diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion")
+
+        prompt = 25 * "@"
+        with CaptureLogger(logger) as cap_logger_3:
+            text_embeddings_3 = sd_pipe._encode_prompt(
+                prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+            )
+
+        prompt = 100 * "@"
+        with CaptureLogger(logger) as cap_logger:
+            text_embeddings = sd_pipe._encode_prompt(
+                prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+            )
+
+        negative_prompt = "Hello"
+        with CaptureLogger(logger) as cap_logger_2:
+            text_embeddings_2 = sd_pipe._encode_prompt(
+                prompt, torch_device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+            )
+
+        assert text_embeddings_3.shape == text_embeddings_2.shape == text_embeddings.shape
+        assert text_embeddings.shape[1] == 77
+
+        assert cap_logger.out == cap_logger_2.out
+        # 100 - 77 + 1 (BOS token) + 1 (EOS token) = 25
+        assert cap_logger.out.count("@") == 25
+        assert cap_logger_3.out == ""
+
 
 @slow
 @require_torch_gpu
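
Note (not part of the patch): a minimal standalone sketch of the truncation-warning pattern this diff applies in each pipeline's _encode_prompt. The prompt is tokenized twice, once with truncation and once without, and the dropped tokens are decoded only when the two results differ. The checkpoint name "openai/clip-vit-base-patch32" and the sample prompt are illustrative assumptions; the pipelines use whichever CLIP tokenizer they were loaded with.

import torch
from transformers import CLIPTokenizer

# Any CLIP tokenizer with model_max_length == 77 works; this checkpoint is only an example.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
prompt = "a photo of " + "a very detailed castle " * 30  # well over 77 tokens

# Tokenize twice: once truncated to the model limit, once only padded.
text_input_ids = tokenizer(
    prompt,
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
).input_ids
untruncated_ids = tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids

# If the two disagree, the prompt was truncated. Decode only the dropped span:
# position model_max_length - 1 holds EOS in the truncated ids, and the last
# untruncated token is EOS, hence the slice [model_max_length - 1 : -1].
if not torch.equal(text_input_ids, untruncated_ids):
    removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
    print(f"Truncated part of the prompt: {removed_text}")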