From 32ea2142c056fae722b0cabaa799697a861cd039 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Fri, 2 Jun 2023 08:57:20 +0100
Subject: [PATCH] [Kandinsky] Improve kandinsky API a bit (#3636)

* Improve docs

* up

* Update docs/source/en/api/pipelines/kandinsky.mdx

* up

* up

* correct more

* further improve

* Update docs/source/en/api/pipelines/kandinsky.mdx

Co-authored-by: YiYi Xu

---------

Co-authored-by: YiYi Xu
---
 docs/source/en/api/pipelines/kandinsky.mdx    | 203 +++++++++++-------
 .../pipelines/kandinsky/pipeline_kandinsky.py |  11 +-
 .../kandinsky/pipeline_kandinsky_img2img.py   |  11 +-
 .../kandinsky/pipeline_kandinsky_inpaint.py   |  11 +-
 .../kandinsky/pipeline_kandinsky_prior.py     |  53 +++--
 tests/pipelines/kandinsky/test_kandinsky.py   |   6 +-
 .../kandinsky/test_kandinsky_img2img.py       |   6 +-
 .../kandinsky/test_kandinsky_inpaint.py       |   6 +-
 .../kandinsky/test_kandinsky_prior.py         |   2 +-
 tests/pipelines/test_pipelines_common.py      |   2 +-
 10 files changed, 187 insertions(+), 124 deletions(-)

diff --git a/docs/source/en/api/pipelines/kandinsky.mdx b/docs/source/en/api/pipelines/kandinsky.mdx
index b5b4f0f064..b94937e4af 100644
--- a/docs/source/en/api/pipelines/kandinsky.mdx
+++ b/docs/source/en/api/pipelines/kandinsky.mdx
@@ -19,81 +19,78 @@ The Kandinsky model is created by [Arseniy Shakhmatov](https://github.com/cene55
 
 ## Available Pipelines:
 
-| Pipeline | Tasks | Colab
-|---|---|:---:|
-| [pipeline_kandinsky.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py) | *Text-to-Image Generation* | - |
-| [pipeline_kandinsky_inpaint.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py) | *Image-Guided Image Generation* | - |
-| [pipeline_kandinsky_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py) | *Image-Guided Image Generation* | - |
+| Pipeline | Tasks |
+|---|---|
+| [pipeline_kandinsky.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py) | *Text-to-Image Generation* |
+| [pipeline_kandinsky_inpaint.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py) | *Image-Guided Image Generation* |
+| [pipeline_kandinsky_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py) | *Image-Guided Image Generation* |
 
 ## Usage example
 
-In the following, we will walk you through some cool examples of using the Kandinsky pipelines to create some visually aesthetic artwork.
+In the following, we will walk you through some examples of how to use the Kandinsky pipelines to create visually appealing artwork.
 
 ### Text-to-Image Generation
 
-For text-to-image generation, we need to use both [`KandinskyPriorPipeline`] and [`KandinskyPipeline`]. The first step is to encode text prompts with CLIP and then diffuse the CLIP text embeddings to CLIP image embeddings, as first proposed in [DALL-E 2](https://cdn.openai.com/papers/dall-e-2.pdf). Let's throw a fun prompt at Kandinsky to see what it comes up with :)
+For text-to-image generation, we need to use both [`KandinskyPriorPipeline`] and [`KandinskyPipeline`].
+The first step is to encode text prompts with CLIP and then diffuse the CLIP text embeddings to CLIP image embeddings,
+as first proposed in [DALL-E 2](https://cdn.openai.com/papers/dall-e-2.pdf). 
+Let's throw a fun prompt at Kandinsky to see what it comes up with.
 
-```python
+```py
 prompt = "A alien cheeseburger creature eating itself, claymation, cinematic, moody lighting"
-negative_prompt = "low quality, bad quality"
 ```
 
-We will pass both the `prompt` and `negative_prompt` to our prior diffusion pipeline. In contrast to other diffusion pipelines, such as Stable Diffusion, the `prompt` and `negative_prompt` shall be passed separately so that we can retrieve a CLIP image embedding for each prompt input. You can use `guidance_scale`, and `num_inference_steps` arguments to guide this process, just like how you would normally do with all other pipelines in diffusers.
+First, let's instantiate the prior pipeline and the text-to-image pipeline. Both
+pipelines are diffusion models.
 
+
+```py
+from diffusers import DiffusionPipeline
 import torch
 
-# create prior
-pipe_prior = KandinskyPriorPipeline.from_pretrained(
-    "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16
-)
+pipe_prior = DiffusionPipeline.from_pretrained("kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16)
 pipe_prior.to("cuda")
 
-generator = torch.Generator(device="cuda").manual_seed(12)
-image_emb = pipe_prior(
-    prompt, guidance_scale=1.0, num_inference_steps=25, generator=generator, negative_prompt=negative_prompt
-).images
-
-zero_image_emb = pipe_prior(
-    negative_prompt, guidance_scale=1.0, num_inference_steps=25, generator=generator, negative_prompt=negative_prompt
-).images
+t2i_pipe = DiffusionPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16)
+t2i_pipe.to("cuda")
 ```
 
-Once we create the image embedding, we can use [`KandinskyPipeline`] to generate images.
+Now we pass the prompt through the prior to generate image embeddings. The prior
+returns both the image embeddings corresponding to the prompt and negative/unconditional image
+embeddings corresponding to an empty string.
 
-```python
-from PIL import Image
-from diffusers import KandinskyPipeline
+```py
+generator = torch.Generator(device="cuda").manual_seed(12)
+image_embeds, negative_image_embeds = pipe_prior(prompt, generator=generator).to_tuple()
+```
+
+<Tip warning={true}>
+
+The text-to-image pipeline expects `image_embeds`, `negative_image_embeds`, and the original
+`prompt`, because the text-to-image pipeline uses another text encoder to better guide the second diffusion
+process of `t2i_pipe`.
+
+By default, the prior returns unconditioned negative image embeddings corresponding to the negative prompt of `""`.
+For better results, you can also pass a `negative_prompt` to the prior. This will increase the effective batch size
+of the prior by a factor of 2.
+
+```py
+prompt = "A alien cheeseburger creature eating itself, claymation, cinematic, moody lighting"
+negative_prompt = "low quality, bad quality"
+
+image_embeds, negative_image_embeds = pipe_prior(prompt, negative_prompt, generator=generator).to_tuple()
+```
+
+</Tip>
 
-def image_grid(imgs, rows, cols):
-    assert len(imgs) == rows * cols
+Next, we can pass the embeddings as well as the prompt to the text-to-image pipeline. Remember that
+if you are using a customized negative prompt, you should also pass it to the text-to-image pipeline
+with `negative_prompt=negative_prompt`:
 
-    w, h = imgs[0].size
-    grid = Image.new("RGB", size=(cols * w, rows * h))
-    grid_w, grid_h = grid.size
-
-    for i, img in enumerate(imgs):
-        grid.paste(img, box=(i % cols * w, i // cols * h))
-    return grid
-
-
-# create diffuser pipeline
-pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16)
-pipe.to("cuda")
-
-images = pipe(
-    prompt,
-    image_embeds=image_emb,
-    negative_image_embeds=zero_image_emb,
-    num_images_per_prompt=2,
-    height=768,
-    width=768,
-    num_inference_steps=100,
-    guidance_scale=4.0,
-    generator=generator,
-).images
+```py
+image = t2i_pipe(prompt, image_embeds=image_embeds, negative_image_embeds=negative_image_embeds).images[0]
+image.save("cheeseburger_monster.png")
 ```
 
 One cheeseburger monster coming up! Enjoy!
 
@@ -164,22 +161,15 @@ prompt = "A fantasy landscape, Cinematic lighting"
 negative_prompt = "low quality, bad quality"
 
 generator = torch.Generator(device="cuda").manual_seed(30)
-image_emb = pipe_prior(
-    prompt, guidance_scale=4.0, num_inference_steps=25, generator=generator, negative_prompt=negative_prompt
-).images
-
-zero_image_emb = pipe_prior(
-    negative_prompt, guidance_scale=4.0, num_inference_steps=25, generator=generator, negative_prompt=negative_prompt
-).images
+image_embeds, negative_image_embeds = pipe_prior(prompt, negative_prompt, generator=generator).to_tuple()
 
 out = pipe(
     prompt,
     image=original_image,
-    image_embeds=image_emb,
-    negative_image_embeds=zero_image_emb,
+    image_embeds=image_embeds,
+    negative_image_embeds=negative_image_embeds,
     height=768,
     width=768,
-    num_inference_steps=500,
     strength=0.3,
 )
 
@@ -193,7 +183,7 @@ out.images[0].save("fantasy_land.png")
 
 You can use [`KandinskyInpaintPipeline`] to edit images. In this example, we will add a hat to the portrait of a cat. 

-```python
+```py
 from diffusers import KandinskyInpaintPipeline, KandinskyPriorPipeline
 from diffusers.utils import load_image
 import torch
@@ -205,7 +195,7 @@ pipe_prior = KandinskyPriorPipeline.from_pretrained(
 pipe_prior.to("cuda")
 
 prompt = "a hat"
-image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False)
+prior_output = pipe_prior(prompt)
 
 pipe = KandinskyInpaintPipeline.from_pretrained("kandinsky-community/kandinsky-2-1-inpaint", torch_dtype=torch.float16)
 pipe.to("cuda")
@@ -222,8 +212,7 @@ out = pipe(
     prompt,
     image=init_image,
     mask_image=mask,
-    image_embeds=image_emb,
-    negative_image_embeds=zero_image_emb,
+    **prior_output,
     height=768,
     width=768,
     num_inference_steps=150,
@@ -246,7 +235,6 @@ from diffusers.utils import load_image
 import PIL
 import torch
 
-from torchvision import transforms
 
 pipe_prior = KandinskyPriorPipeline.from_pretrained(
     "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16
@@ -263,22 +251,80 @@ img2 = load_image(
 
 # add all the conditions we want to interpolate, can be either text or image
 images_texts = ["a cat", img1, img2]
+
 # specify the weights for each condition in images_texts
 weights = [0.3, 0.3, 0.4]
-image_emb, zero_image_emb = pipe_prior.interpolate(images_texts, weights)
+
+# We can leave the prompt empty
+prompt = ""
+prior_out = pipe_prior.interpolate(images_texts, weights)
 
 pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16)
 pipe.to("cuda")
 
-image = pipe(
-    "", image_embeds=image_emb, negative_image_embeds=zero_image_emb, height=768, width=768, num_inference_steps=150
-).images[0]
+image = pipe(prompt, **prior_out, height=768, width=768).images[0]
 
 image.save("starry_cat.png")
 ```
 
 ![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/starry_cat.png)
 
+## Optimization
+
+Running Kandinsky in inference requires running both a first prior pipeline, [`KandinskyPriorPipeline`],
+and a second image decoding pipeline, which is one of [`KandinskyPipeline`], [`KandinskyImg2ImgPipeline`], or [`KandinskyInpaintPipeline`].
+
+The bulk of the computation time will always be spent in the second image decoding pipeline, so when
+optimizing the model, one should focus on the second image decoding pipeline.
+
+When running with PyTorch < 2.0, we strongly recommend making use of [`xformers`](https://github.com/facebookresearch/xformers)
+to speed up inference. This can be done by simply running:
+
+```py
+from diffusers import DiffusionPipeline
+import torch
+
+t2i_pipe = DiffusionPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16)
+t2i_pipe.enable_xformers_memory_efficient_attention()
+```
+
+When running on PyTorch >= 2.0, PyTorch's SDPA attention will automatically be used. For more information on
+PyTorch's SDPA, feel free to have a look at [this blog post](https://pytorch.org/blog/accelerated-diffusers-pt-20/).
+
+To have explicit control, you can also manually set the pipeline to use PyTorch's 2.0 efficient attention:
+
+```py
+from diffusers.models.attention_processor import AttnAddedKVProcessor2_0
+
+t2i_pipe.unet.set_attn_processor(AttnAddedKVProcessor2_0())
+```
+
+The slowest and most memory-intensive attention processor is the default `AttnAddedKVProcessor` processor.
+We do **not** recommend using it except for testing purposes or cases where highly deterministic behaviour is desired. 
+You can set it with:
+
+```py
+from diffusers.models.attention_processor import AttnAddedKVProcessor
+
+t2i_pipe.unet.set_attn_processor(AttnAddedKVProcessor())
+```
+
+With PyTorch >= 2.0, you can also use Kandinsky with `torch.compile` which, depending
+on your hardware, can significantly speed up your inference time once the model is compiled.
+To use Kandinsky with `torch.compile`, you can do:
+
+```py
+t2i_pipe.unet.to(memory_format=torch.channels_last)
+t2i_pipe.unet = torch.compile(t2i_pipe.unet, mode="reduce-overhead", fullgraph=True)
+```
+
+After compilation you should see a very fast inference time. For more information,
+feel free to have a look at [our PyTorch 2.0 benchmark](https://huggingface.co/docs/diffusers/main/en/optimization/torch2.0).
+
+
+
+
 ## KandinskyPriorPipeline
 
 [[autodoc]] KandinskyPriorPipeline
@@ -292,15 +338,14 @@ image.save("starry_cat.png")
 - all
 - __call__
 
-## KandinskyInpaintPipeline
-
-[[autodoc]] KandinskyInpaintPipeline
-    - all
-    - __call__
-
 ## KandinskyImg2ImgPipeline
 
 [[autodoc]] KandinskyImg2ImgPipeline
     - all
     - __call__
 
+## KandinskyInpaintPipeline
+
+[[autodoc]] KandinskyInpaintPipeline
+    - all
+    - __call__
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py
index 29545bd88d..0da9d205f8 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py
@@ -304,12 +304,12 @@ class KandinskyPipeline(DiffusionPipeline):
         prompt: Union[str, List[str]],
         image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
         negative_image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
         height: int = 512,
         width: int = 512,
         num_inference_steps: int = 100,
         guidance_scale: float = 4.0,
         num_images_per_prompt: int = 1,
-        negative_prompt: Optional[Union[str, List[str]]] = None,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
@@ -325,6 +325,9 @@ class KandinskyPipeline(DiffusionPipeline):
                 The clip image embeddings for text prompt, that will be used to condition the image generation.
             negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`):
                 The clip image embeddings for negative text prompt, will be used to condition the image generation.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                if `guidance_scale` is less than `1`).
             height (`int`, *optional*, defaults to 512):
                 The height in pixels of the generated image.
             width (`int`, *optional*, defaults to 512):
@@ -340,9 +343,6 @@ class KandinskyPipeline(DiffusionPipeline):
                 usually at the expense of lower image quality.
             num_images_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
-            negative_prompt (`str` or `List[str]`, *optional*):
-                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
-                if `guidance_scale` is less than `1`).
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                 to make generation deterministic. 
@@ -418,7 +418,8 @@ class KandinskyPipeline(DiffusionPipeline): timestep=t, encoder_hidden_states=text_encoder_hidden_states, added_cond_kwargs=added_cond_kwargs, - ).sample + return_dict=False, + )[0] if do_classifier_free_guidance: noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1) diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py index 470fa606af..f32528617e 100644 --- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py @@ -368,13 +368,13 @@ class KandinskyImg2ImgPipeline(DiffusionPipeline): image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]], image_embeds: torch.FloatTensor, negative_image_embeds: torch.FloatTensor, + negative_prompt: Optional[Union[str, List[str]]] = None, height: int = 512, width: int = 512, num_inference_steps: int = 100, strength: float = 0.3, guidance_scale: float = 7.0, num_images_per_prompt: int = 1, - negative_prompt: Optional[Union[str, List[str]]] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, @@ -392,6 +392,9 @@ class KandinskyImg2ImgPipeline(DiffusionPipeline): The clip image embeddings for text prompt, that will be used to condition the image generation. negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`): The clip image embeddings for negative text prompt, will be used to condition the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). height (`int`, *optional*, defaults to 512): The height in pixels of the generated image. width (`int`, *optional*, defaults to 512): @@ -413,9 +416,6 @@ class KandinskyImg2ImgPipeline(DiffusionPipeline): usually at the expense of lower image quality. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. 
@@ -512,7 +512,8 @@ class KandinskyImg2ImgPipeline(DiffusionPipeline): timestep=t, encoder_hidden_states=text_encoder_hidden_states, added_cond_kwargs=added_cond_kwargs, - ).sample + return_dict=False, + )[0] if do_classifier_free_guidance: noise_pred, _ = noise_pred.split(latents.shape[1], dim=1) diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py index cc9a35e580..04810ddb6e 100644 --- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py @@ -466,12 +466,12 @@ class KandinskyInpaintPipeline(DiffusionPipeline): mask_image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray], image_embeds: torch.FloatTensor, negative_image_embeds: torch.FloatTensor, + negative_prompt: Optional[Union[str, List[str]]] = None, height: int = 512, width: int = 512, num_inference_steps: int = 100, guidance_scale: float = 4.0, num_images_per_prompt: int = 1, - negative_prompt: Optional[Union[str, List[str]]] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", @@ -498,6 +498,9 @@ class KandinskyInpaintPipeline(DiffusionPipeline): The clip image embeddings for text prompt, that will be used to condition the image generation. negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`): The clip image embeddings for negative text prompt, will be used to condition the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). height (`int`, *optional*, defaults to 512): The height in pixels of the generated image. width (`int`, *optional*, defaults to 512): @@ -513,9 +516,6 @@ class KandinskyInpaintPipeline(DiffusionPipeline): usually at the expense of lower image quality. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored - if `guidance_scale` is less than `1`). generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. @@ -629,7 +629,8 @@ class KandinskyInpaintPipeline(DiffusionPipeline): timestep=t, encoder_hidden_states=text_encoder_hidden_states, added_cond_kwargs=added_cond_kwargs, - ).sample + return_dict=False, + )[0] if do_classifier_free_guidance: noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1) diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py index d9474b43da..0c262c57ab 100644 --- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py @@ -116,14 +116,14 @@ class KandinskyPriorPipelineOutput(BaseOutput): Output class for KandinskyPriorPipeline. 
Args:
-    images (`torch.FloatTensor`)
+    image_embeds (`torch.FloatTensor`)
         clip image embeddings for text prompt
-    zero_embeds (`List[PIL.Image.Image]` or `np.ndarray`)
+    negative_image_embeds (`torch.FloatTensor` or `np.ndarray`)
         clip image embeddings for unconditional tokens
     """
 
-    images: Union[torch.FloatTensor, np.ndarray]
-    zero_embeds: Union[torch.FloatTensor, np.ndarray]
+    image_embeds: Union[torch.FloatTensor, np.ndarray]
+    negative_image_embeds: Union[torch.FloatTensor, np.ndarray]
 
 
 class KandinskyPriorPipeline(DiffusionPipeline):
@@ -231,7 +231,7 @@ class KandinskyPriorPipeline(DiffusionPipeline):
         image_embeddings = []
         for cond, weight in zip(images_and_prompts, weights):
             if isinstance(cond, str):
-                image_emb = self.__call__(
+                image_emb = self(
                     cond,
                     num_inference_steps=num_inference_steps,
                     num_images_per_prompt=num_images_per_prompt,
@@ -239,7 +239,7 @@ class KandinskyPriorPipeline(DiffusionPipeline):
                     latents=latents,
                     negative_prompt=negative_prior_prompt,
                     guidance_scale=guidance_scale,
-                ).images
+                ).image_embeds
 
             elif isinstance(cond, (PIL.Image.Image, torch.Tensor)):
                 if isinstance(cond, PIL.Image.Image):
@@ -261,7 +261,7 @@ class KandinskyPriorPipeline(DiffusionPipeline):
 
         image_emb = torch.cat(image_embeddings).sum(dim=0, keepdim=True)
 
-        out_zero = self.__call__(
+        out_zero = self(
             negative_prompt,
             num_inference_steps=num_inference_steps,
             num_images_per_prompt=num_images_per_prompt,
@@ -270,9 +270,9 @@ class KandinskyPriorPipeline(DiffusionPipeline):
             negative_prompt=negative_prior_prompt,
             guidance_scale=guidance_scale,
         )
-        zero_image_emb = out_zero.zero_embeds if negative_prompt == "" else out_zero.images
+        zero_image_emb = out_zero.negative_image_embeds if negative_prompt == "" else out_zero.image_embeds
 
-        return image_emb, zero_image_emb
+        return KandinskyPriorPipelineOutput(image_embeds=image_emb, negative_image_embeds=zero_image_emb)
 
     def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
         if latents is None:
@@ -435,11 +435,11 @@ class KandinskyPriorPipeline(DiffusionPipeline):
     def __call__(
         self,
         prompt: Union[str, List[str]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: int = 1,
         num_inference_steps: int = 25,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
-        negative_prompt: Optional[Union[str, List[str]]] = None,
         guidance_scale: float = 4.0,
         output_type: Optional[str] = "pt",  # pt only
         return_dict: bool = True,
@@ -450,6 +450,9 @@ class KandinskyPriorPipeline(DiffusionPipeline):
         Args:
             prompt (`str` or `List[str]`):
                 The prompt or prompts to guide the image generation.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                if `guidance_scale` is less than `1`).
             num_images_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             num_inference_steps (`int`, *optional*, defaults to 100):
@@ -462,9 +465,6 @@ class KandinskyPriorPipeline(DiffusionPipeline):
                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                 tensor will ge generated by sampling using the supplied random `generator`.
-            negative_prompt (`str` or `List[str]`, *optional*):
-                The prompt or prompts not to guide the image generation. 
Ignored when not using guidance (i.e., ignored
-                if `guidance_scale` is less than `1`).
             guidance_scale (`float`, *optional*, defaults to 4.0):
                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -484,14 +484,24 @@
         """
 
         if isinstance(prompt, str):
-            batch_size = 1
-        elif isinstance(prompt, list):
-            batch_size = len(prompt)
-        else:
+            prompt = [prompt]
+        elif not isinstance(prompt, list):
             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
 
+        if isinstance(negative_prompt, str):
+            negative_prompt = [negative_prompt]
+        elif not isinstance(negative_prompt, list) and negative_prompt is not None:
+            raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+        # if the negative prompt is defined, we double the batch size to
+        # directly retrieve the negative prompt embedding
+        if negative_prompt is not None:
+            prompt = prompt + negative_prompt
+            negative_prompt = 2 * negative_prompt
+
         device = self._execution_device
 
+        batch_size = len(prompt)
         batch_size = batch_size * num_images_per_prompt
 
         do_classifier_free_guidance = guidance_scale > 1.0
@@ -548,7 +558,12 @@
         latents = self.prior.post_process_latents(latents)
 
         image_embeddings = latents
-        zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device)
+
+        # if a negative prompt has been defined, we split the image embedding into two
+        if negative_prompt is None:
+            zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device)
+        else:
+            image_embeddings, zero_embeds = image_embeddings.chunk(2)
 
         if output_type not in ["pt", "np"]:
             raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}")
@@ -560,4 +575,4 @@
         if not return_dict:
             return (image_embeddings, zero_embeds)
 
-        return KandinskyPriorPipelineOutput(images=image_embeddings, zero_embeds=zero_embeds)
+        return KandinskyPriorPipelineOutput(image_embeds=image_embeddings, negative_image_embeds=zero_embeds)
diff --git a/tests/pipelines/kandinsky/test_kandinsky.py b/tests/pipelines/kandinsky/test_kandinsky.py
index 8f7d5ae201..239433910b 100644
--- a/tests/pipelines/kandinsky/test_kandinsky.py
+++ b/tests/pipelines/kandinsky/test_kandinsky.py
@@ -258,12 +258,12 @@ class KandinskyPipelineIntegrationTests(unittest.TestCase):
         prompt = "red cat, 4k photo"
 
         generator = torch.Generator(device="cuda").manual_seed(0)
-        image_emb = pipe_prior(
+        image_emb, zero_image_emb = pipe_prior(
             prompt,
             generator=generator,
             num_inference_steps=5,
-        ).images
-        zero_image_emb = pipe_prior("", num_inference_steps=5).images
+            negative_prompt="",
+        ).to_tuple()
 
         generator = torch.Generator(device="cuda").manual_seed(0)
         output = pipeline(
diff --git a/tests/pipelines/kandinsky/test_kandinsky_img2img.py b/tests/pipelines/kandinsky/test_kandinsky_img2img.py
index 6958403ae1..94817b3eed 100644
--- a/tests/pipelines/kandinsky/test_kandinsky_img2img.py
+++ b/tests/pipelines/kandinsky/test_kandinsky_img2img.py
@@ -276,12 +276,12 @@ class KandinskyImg2ImgPipelineIntegrationTests(unittest.TestCase):
         pipeline.set_progress_bar_config(disable=None)
 
         generator = torch.Generator(device="cpu").manual_seed(0)
-        image_emb = pipe_prior(
+        image_emb, zero_image_emb = pipe_prior(
             prompt,
             generator=generator,
             num_inference_steps=5,
-        ).images
-        
zero_image_emb = pipe_prior("", num_inference_steps=5).images + negative_prompt="", + ).to_tuple() output = pipeline( prompt, diff --git a/tests/pipelines/kandinsky/test_kandinsky_inpaint.py b/tests/pipelines/kandinsky/test_kandinsky_inpaint.py index 1bca753bec..46926479ae 100644 --- a/tests/pipelines/kandinsky/test_kandinsky_inpaint.py +++ b/tests/pipelines/kandinsky/test_kandinsky_inpaint.py @@ -286,12 +286,12 @@ class KandinskyInpaintPipelineIntegrationTests(unittest.TestCase): pipeline.set_progress_bar_config(disable=None) generator = torch.Generator(device="cpu").manual_seed(0) - image_emb = pipe_prior( + image_emb, zero_image_emb = pipe_prior( prompt, generator=generator, num_inference_steps=5, - ).images - zero_image_emb = pipe_prior("").images + negative_prompt="", + ).to_tuple() output = pipeline( prompt, diff --git a/tests/pipelines/kandinsky/test_kandinsky_prior.py b/tests/pipelines/kandinsky/test_kandinsky_prior.py index 5ed1f2ac98..d9c260eabc 100644 --- a/tests/pipelines/kandinsky/test_kandinsky_prior.py +++ b/tests/pipelines/kandinsky/test_kandinsky_prior.py @@ -194,7 +194,7 @@ class KandinskyPriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase): pipe.set_progress_bar_config(disable=None) output = pipe(**self.get_dummy_inputs(device)) - image = output.images + image = output.image_embeds image_from_tuple = pipe( **self.get_dummy_inputs(device), diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 3ddfd35def..8ce0a0f283 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -650,7 +650,7 @@ class PipelineTesterMixin: if key in self.batch_params: inputs[key] = batch_size * [inputs[key]] - images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images + images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt)[0] assert images.shape[0] == batch_size * num_images_per_prompt
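
A note on the core behavioural change in `KandinskyPriorPipeline.__call__` above: when a `negative_prompt` is given, it is concatenated onto `prompt` so that a single denoising pass produces both embeddings, which are then separated with `chunk(2)`; when it is not, the pipeline falls back to `get_zero_embed`. The following is a minimal, self-contained sketch of that control flow; `run_prior` is a hypothetical stand-in for the real CLIP-encode-plus-prior-denoising step and is not an API from this diff:

```py
import torch


def prior_embeddings(run_prior, prompt, negative_prompt=None):
    # Normalize inputs to lists, mirroring the type checks added in __call__.
    if isinstance(prompt, str):
        prompt = [prompt]
    if isinstance(negative_prompt, str):
        negative_prompt = [negative_prompt]

    # If a negative prompt is given, fold it into the batch so that one
    # forward pass yields both the conditional and the negative embeddings.
    if negative_prompt is not None:
        prompt = prompt + negative_prompt

    # run_prior stands in for text encoding plus the prior's denoising loop;
    # it returns one image embedding per input prompt.
    image_embeddings = run_prior(prompt)

    if negative_prompt is None:
        # Without a negative prompt, the real pipeline uses get_zero_embed
        # (a zero image passed through the CLIP image encoder); zeros are
        # used here only to keep the sketch self-contained.
        negative_embeddings = torch.zeros_like(image_embeddings)
    else:
        # The first half of the batch is conditional, the second half negative.
        image_embeddings, negative_embeddings = image_embeddings.chunk(2)

    return image_embeddings, negative_embeddings


# Toy usage with a dummy prior that maps each prompt to a random 768-dim vector:
def dummy_prior(prompts):
    return torch.randn(len(prompts), 768)


cond, neg = prior_embeddings(dummy_prior, "red cat, 4k photo", "low quality")
print(cond.shape, neg.shape)  # torch.Size([1, 768]) torch.Size([1, 768])
```

This batching trick is also why the updated documentation notes that passing a `negative_prompt` doubles the prior's effective batch size.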