From 7d756813d4b094b1c19d3890d2afdc4ff54f4f25 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Tue, 28 Mar 2023 21:00:49 +0530 Subject: [PATCH] Update the legacy inpainting SD pipeline, to allow calling it with only prompt_embeds (instead of always requiring a prompt) (#2842) Fix error 'required positional argument: prompt' when Legacy Inpaint is called only with prompt_embeds --- .../pipeline_stable_diffusion_inpaint_legacy.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py index accbf9674e..feb13d1000 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py @@ -521,7 +521,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): @torch.no_grad() def __call__( self, - prompt: Union[str, List[str]], + prompt: Union[str, List[str]] = None, image: Union[torch.FloatTensor, PIL.Image.Image] = None, mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, strength: float = 0.8, @@ -611,10 +611,16 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): (nsfw) content, according to the `safety_checker`. """ # 1. Check inputs - self.check_inputs(prompt, strength, callback_steps) + self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) # 2. Define call parameters - batch_size = 1 if isinstance(prompt, str) else len(prompt) + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`