From d43972ae71ffbc94dee7045ecfe1e5f7c6ac329e Mon Sep 17 00:00:00 2001 From: "Jorge C. Gomes" Date: Tue, 7 Feb 2023 08:10:24 +0000 Subject: [PATCH] Fixes prompt input checks in StableDiffusion img2img pipeline (#2206) * Fixes prompt input checks in img2img Allows providing `prompt_embeds` instead of `prompt`, which is currently not possible because the first check in `check_inputs` raises a `ValueError` whenever `prompt` is not a `str` or a `list` (including when it is `None`). With that check removed, `check_inputs` becomes the same as the function found in https://github.com/huggingface/diffusers/blob/8267c7844504b55366525169187767ef92d1f499/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L393 * Continues the fix The `batch_size` computation also needs to be fixed: it previously called `len(prompt)` on anything that was not a `str`, and now falls back to `prompt_embeds.shape[0]` when `prompt` is neither a `str` nor a `list`. This becomes consistent with https://github.com/huggingface/diffusers/blob/8267c7844504b55366525169187767ef92d1f499/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L558 I've now tested this implementation, and it produces the expected results. --- .../pipeline_stable_diffusion_img2img.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 09eef31cfe..3f988c31dd 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -428,9 +428,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): def check_inputs( self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None ): - if not isinstance(prompt, str) and not isinstance(prompt, list): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") @@ -623,7 +620,12 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): self.check_inputs(prompt, strength, callback_steps, negative_prompt, 
prompt_embeds, negative_prompt_embeds) # 2. Define call parameters - batch_size = 1 if isinstance(prompt, str) else len(prompt) + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`