1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Allow passing different prompts to each text_encoder on stable_diffusion_xl pipelines (#4156)

* sdxl prompt2

* Improve checks

* doc linting

* whoops

* remove cat

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Add other pipelines and tests

* Add multi-prompting to docs

* doc and copies check

* Fix copied froms

* Apply suggestions from code review

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Bring back the original code for unrelated files

* Fix tests

* Fix img2img

* Fix all

* fix

---------

Co-authored-by: multimodalart <joaopaulo.passos+multimodal@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
apolinário
2023-07-21 14:50:22 +02:00
committed by GitHub
parent e2bbaa4f54
commit aed30dff6b
9 changed files with 469 additions and 68 deletions

View File

@@ -21,6 +21,7 @@ The abstract of the paper is the following:
## Tips
- Stable Diffusion XL works especially well with images between 768 and 1024.
- Stable Diffusion XL can pass a different prompt for each of the text encoders it was trained on as shown below. We can even pass different parts of the same prompt to the text encoders.
- Stable Diffusion XL output image can be improved by making use of a refiner as shown below.
### Available checkpoints:
@@ -362,3 +363,25 @@ pip install xformers
[[autodoc]] StableDiffusionXLInpaintPipeline
- all
- __call__
### Passing different prompts to each text-encoder
Stable Diffusion XL was trained on two text encoders. The default behavior is to pass the same prompt to each. But it is possible to pass a different prompt for each text-encoder, as [some users](https://github.com/huggingface/diffusers/issues/4004#issuecomment-1627764201) noted that it can boost quality.
To do so, you can pass `prompt_2` and `negative_prompt_2` in addition to `prompt` and `negative_prompt`. By doing that, you will pass the original prompts and negative prompts (as in `prompt` and `negative_prompt`) to `text_encoder` (in official SDXL 0.9/1.0 that is [OpenAI CLIP-ViT/L-14](https://huggingface.co/openai/clip-vit-large-patch14)),
and `prompt_2` and `negative_prompt_2` to `text_encoder_2` (in official SDXL 0.9/1.0 that is [OpenCLIP-ViT/bigG-14](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)).
```py
from diffusers import StableDiffusionXLPipeline
import torch
pipe = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
)
pipe.to("cuda")
# prompt will be passed to OAI CLIP-ViT/L-14
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
# prompt_2 will be passed to OpenCLIP-ViT/bigG-14
prompt_2 = "monet painting"
image = pipe(prompt=prompt, prompt_2=prompt_2).images[0]
```

View File

@@ -196,11 +196,13 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
def encode_prompt(
self,
prompt,
prompt: str,
prompt_2: Optional[str] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
negative_prompt=None,
negative_prompt: Optional[str] = None,
negative_prompt_2: Optional[str] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -211,8 +213,11 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in both text-encoders
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
@@ -223,6 +228,9 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -261,9 +269,11 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
)
if prompt_embeds is None:
prompt_2 = prompt_2 or prompt
# textual inversion: procecss multi-vector tokens if necessary
prompt_embeds_list = []
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
prompts = [prompt, prompt_2]
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, tokenizer)
@@ -274,8 +284,10 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
@@ -311,6 +323,8 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
elif do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt_2 = negative_prompt_2 or negative_prompt
uncond_tokens: List[str]
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
@@ -318,7 +332,7 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
uncond_tokens = [negative_prompt, negative_prompt_2]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
@@ -326,17 +340,16 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
uncond_tokens = [negative_prompt, negative_prompt_2]
negative_prompt_embeds_list = []
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
# textual inversion: procecss multi-vector tokens if necessary
for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)
negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
max_length = prompt_embeds.shape[1]
uncond_input = tokenizer(
uncond_tokens,
negative_prompt,
padding="max_length",
max_length=max_length,
truncation=True,
@@ -401,9 +414,11 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
def check_inputs(
self,
prompt,
prompt_2,
image,
callback_steps,
negative_prompt=None,
negative_prompt_2=None,
prompt_embeds=None,
negative_prompt_embeds=None,
controlnet_conditioning_scale=1.0,
@@ -423,18 +438,30 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt_2 is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
@@ -610,6 +637,7 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
image: Union[
torch.FloatTensor,
PIL.Image.Image,
@@ -623,6 +651,7 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -649,6 +678,9 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in both text-encoders
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
`List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
@@ -674,6 +706,9 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
@@ -749,9 +784,11 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
prompt_2,
image,
callback_steps,
negative_prompt,
negative_prompt_2,
prompt_embeds,
negative_prompt_embeds,
controlnet_conditioning_scale,
@@ -791,10 +828,12 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
negative_pooled_prompt_embeds,
) = self.encode_prompt(
prompt,
prompt_2,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
negative_prompt_2,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,

View File

@@ -211,11 +211,13 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
def encode_prompt(
self,
prompt,
prompt: str,
prompt_2: Optional[str] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
negative_prompt=None,
negative_prompt: Optional[str] = None,
negative_prompt_2: Optional[str] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -226,8 +228,11 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in both text-encoders
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
@@ -238,6 +243,9 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -276,9 +284,11 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
)
if prompt_embeds is None:
prompt_2 = prompt_2 or prompt
# textual inversion: procecss multi-vector tokens if necessary
prompt_embeds_list = []
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
prompts = [prompt, prompt_2]
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, tokenizer)
@@ -289,8 +299,10 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
@@ -326,6 +338,8 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
elif do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt_2 = negative_prompt_2 or negative_prompt
uncond_tokens: List[str]
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
@@ -333,7 +347,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
uncond_tokens = [negative_prompt, negative_prompt_2]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
@@ -341,17 +355,16 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
uncond_tokens = [negative_prompt, negative_prompt_2]
negative_prompt_embeds_list = []
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
# textual inversion: procecss multi-vector tokens if necessary
for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)
negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
max_length = prompt_embeds.shape[1]
uncond_input = tokenizer(
uncond_tokens,
negative_prompt,
padding="max_length",
max_length=max_length,
truncation=True,
@@ -416,10 +429,12 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
def check_inputs(
self,
prompt,
prompt_2,
height,
width,
callback_steps,
negative_prompt=None,
negative_prompt_2=None,
prompt_embeds=None,
negative_prompt_embeds=None,
pooled_prompt_embeds=None,
@@ -441,18 +456,30 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt_2 is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
@@ -531,12 +558,14 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
denoising_end: Optional[float] = None,
guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -562,6 +591,9 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in both text-encoders
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -587,6 +619,9 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
@@ -660,10 +695,12 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
prompt_2,
height,
width,
callback_steps,
negative_prompt,
negative_prompt_2,
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
@@ -695,11 +732,13 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt=prompt,
prompt_2=prompt_2,
device=device,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=do_classifier_free_guidance,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,

View File

@@ -219,11 +219,13 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
def encode_prompt(
self,
prompt,
prompt: str,
prompt_2: Optional[str] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
negative_prompt=None,
negative_prompt: Optional[str] = None,
negative_prompt_2: Optional[str] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -234,8 +236,11 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in both text-encoders
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
@@ -246,6 +251,9 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -284,9 +292,11 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
)
if prompt_embeds is None:
prompt_2 = prompt_2 or prompt
# textual inversion: procecss multi-vector tokens if necessary
prompt_embeds_list = []
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
prompts = [prompt, prompt_2]
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, tokenizer)
@@ -297,8 +307,10 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
@@ -334,6 +346,8 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
elif do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt_2 = negative_prompt_2 or negative_prompt
uncond_tokens: List[str]
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
@@ -341,7 +355,7 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
uncond_tokens = [negative_prompt, negative_prompt_2]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
@@ -349,17 +363,16 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
uncond_tokens = [negative_prompt, negative_prompt_2]
negative_prompt_embeds_list = []
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
# textual inversion: procecss multi-vector tokens if necessary
for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)
negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
max_length = prompt_embeds.shape[1]
uncond_input = tokenizer(
uncond_tokens,
negative_prompt,
padding="max_length",
max_length=max_length,
truncation=True,
@@ -424,10 +437,12 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
def check_inputs(
self,
prompt,
prompt_2,
strength,
num_inference_steps,
callback_steps,
negative_prompt=None,
negative_prompt_2=None,
prompt_embeds=None,
negative_prompt_embeds=None,
):
@@ -453,18 +468,30 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt_2 is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
@@ -617,6 +644,7 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
image: Union[
torch.FloatTensor,
PIL.Image.Image,
@@ -631,6 +659,7 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
denoising_end: Optional[float] = None,
guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -658,6 +687,9 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in both text-encoders
image (`torch.FloatTensor` or `PIL.Image.Image` or `np.ndarray` or `List[torch.FloatTensor]` or `List[PIL.Image.Image]` or `List[np.ndarray]`):
The image(s) to modify with the pipeline.
strength (`float`, *optional*, defaults to 0.3):
@@ -697,6 +729,9 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
@@ -767,10 +802,12 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
prompt_2,
strength,
num_inference_steps,
callback_steps,
negative_prompt,
negative_prompt_2,
prompt_embeds,
negative_prompt_embeds,
)
@@ -800,11 +837,13 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt=prompt,
prompt_2=prompt_2,
device=device,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=do_classifier_free_guidance,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,

View File

@@ -325,11 +325,13 @@ class StableDiffusionXLInpaintPipeline(
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
def encode_prompt(
self,
prompt,
prompt: str,
prompt_2: Optional[str] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
negative_prompt=None,
negative_prompt: Optional[str] = None,
negative_prompt_2: Optional[str] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -340,8 +342,11 @@ class StableDiffusionXLInpaintPipeline(
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in both text-encoders
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
@@ -352,6 +357,9 @@ class StableDiffusionXLInpaintPipeline(
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -390,9 +398,11 @@ class StableDiffusionXLInpaintPipeline(
)
if prompt_embeds is None:
prompt_2 = prompt_2 or prompt
# textual inversion: procecss multi-vector tokens if necessary
prompt_embeds_list = []
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
prompts = [prompt, prompt_2]
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, tokenizer)
@@ -403,8 +413,10 @@ class StableDiffusionXLInpaintPipeline(
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
@@ -440,6 +452,8 @@ class StableDiffusionXLInpaintPipeline(
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
elif do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt_2 = negative_prompt_2 or negative_prompt
uncond_tokens: List[str]
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
@@ -447,7 +461,7 @@ class StableDiffusionXLInpaintPipeline(
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
uncond_tokens = [negative_prompt, negative_prompt_2]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
@@ -455,17 +469,16 @@ class StableDiffusionXLInpaintPipeline(
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
uncond_tokens = [negative_prompt, negative_prompt_2]
negative_prompt_embeds_list = []
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
# textual inversion: procecss multi-vector tokens if necessary
for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)
negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
max_length = prompt_embeds.shape[1]
uncond_input = tokenizer(
uncond_tokens,
negative_prompt,
padding="max_length",
max_length=max_length,
truncation=True,
@@ -527,15 +540,16 @@ class StableDiffusionXLInpaintPipeline(
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.check_inputs
def check_inputs(
self,
prompt,
prompt_2,
height,
width,
strength,
callback_steps,
negative_prompt=None,
negative_prompt_2=None,
prompt_embeds=None,
negative_prompt_embeds=None,
):
@@ -558,18 +572,30 @@ class StableDiffusionXLInpaintPipeline(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt_2 is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
@@ -785,6 +811,7 @@ class StableDiffusionXLInpaintPipeline(
def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
image: Union[torch.FloatTensor, PIL.Image.Image] = None,
mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
height: Optional[int] = None,
@@ -795,6 +822,7 @@ class StableDiffusionXLInpaintPipeline(
denoising_end: Optional[float] = None,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -822,6 +850,9 @@ class StableDiffusionXLInpaintPipeline(
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in both text-encoders
image (`PIL.Image.Image`):
`Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
be masked out with `mask_image` and repainted according to `prompt`.
@@ -868,6 +899,13 @@ class StableDiffusionXLInpaintPipeline(
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -894,13 +932,6 @@ class StableDiffusionXLInpaintPipeline(
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -962,11 +993,13 @@ class StableDiffusionXLInpaintPipeline(
# 1. Check inputs
self.check_inputs(
prompt,
prompt_2,
height,
width,
strength,
callback_steps,
negative_prompt,
negative_prompt_2,
prompt_embeds,
negative_prompt_embeds,
)
@@ -996,11 +1029,13 @@ class StableDiffusionXLInpaintPipeline(
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt=prompt,
prompt_2=prompt_2,
device=device,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=do_classifier_free_guidance,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,

View File

@@ -15,6 +15,7 @@
import unittest
import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
@@ -177,3 +178,56 @@ class ControlNetPipelineSDXLFastTests(
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
def test_stable_diffusion_xl_multi_prompts(self):
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components).to(torch_device)
# forward with single prompt
inputs = self.get_dummy_inputs(torch_device)
output = sd_pipe(**inputs)
image_slice_1 = output.images[0, -3:, -3:, -1]
# forward with same prompt duplicated
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt_2"] = inputs["prompt"]
output = sd_pipe(**inputs)
image_slice_2 = output.images[0, -3:, -3:, -1]
# ensure the results are equal
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
# forward with different prompt
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt_2"] = "different prompt"
output = sd_pipe(**inputs)
image_slice_3 = output.images[0, -3:, -3:, -1]
# ensure the results are not equal
assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4
# manually set a negative_prompt
inputs = self.get_dummy_inputs(torch_device)
inputs["negative_prompt"] = "negative prompt"
output = sd_pipe(**inputs)
image_slice_1 = output.images[0, -3:, -3:, -1]
# forward with same negative_prompt duplicated
inputs = self.get_dummy_inputs(torch_device)
inputs["negative_prompt"] = "negative prompt"
inputs["negative_prompt_2"] = inputs["negative_prompt"]
output = sd_pipe(**inputs)
image_slice_2 = output.images[0, -3:, -3:, -1]
# ensure the results are equal
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
# forward with different negative_prompt
inputs = self.get_dummy_inputs(torch_device)
inputs["negative_prompt"] = "negative prompt"
inputs["negative_prompt_2"] = "different negative prompt"
output = sd_pipe(**inputs)
image_slice_3 = output.images[0, -3:, -3:, -1]
# ensure the results are not equal
assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4

View File

@@ -355,3 +355,56 @@ class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTest
HeunDiscreteScheduler,
]:
assert_run_mixture(steps, split_1, split_2, scheduler_cls)
def test_stable_diffusion_xl_multi_prompts(self):
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components).to(torch_device)
# forward with single prompt
inputs = self.get_dummy_inputs(torch_device)
output = sd_pipe(**inputs)
image_slice_1 = output.images[0, -3:, -3:, -1]
# forward with same prompt duplicated
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt_2"] = inputs["prompt"]
output = sd_pipe(**inputs)
image_slice_2 = output.images[0, -3:, -3:, -1]
# ensure the results are equal
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
# forward with different prompt
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt_2"] = "different prompt"
output = sd_pipe(**inputs)
image_slice_3 = output.images[0, -3:, -3:, -1]
# ensure the results are not equal
assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4
# manually set a negative_prompt
inputs = self.get_dummy_inputs(torch_device)
inputs["negative_prompt"] = "negative prompt"
output = sd_pipe(**inputs)
image_slice_1 = output.images[0, -3:, -3:, -1]
# forward with same negative_prompt duplicated
inputs = self.get_dummy_inputs(torch_device)
inputs["negative_prompt"] = "negative prompt"
inputs["negative_prompt_2"] = inputs["negative_prompt"]
output = sd_pipe(**inputs)
image_slice_2 = output.images[0, -3:, -3:, -1]
# ensure the results are equal
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
# forward with different negative_prompt
inputs = self.get_dummy_inputs(torch_device)
inputs["negative_prompt"] = "negative prompt"
inputs["negative_prompt_2"] = "different negative prompt"
output = sd_pipe(**inputs)
image_slice_3 = output.images[0, -3:, -3:, -1]
# ensure the results are not equal
assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4

View File

@@ -113,8 +113,6 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
"tokenizer": tokenizer,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
# "safety_checker": None,
# "feature_extractor": None,
}
return components
@@ -132,7 +130,7 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
"num_inference_steps": 2,
"guidance_scale": 5.0,
"output_type": "numpy",
"strength": 0.75,
"strength": 0.8,
}
return inputs
@@ -231,3 +229,62 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3
def test_stable_diffusion_xl_multi_prompts(self):
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components).to(torch_device)
# forward with single prompt
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = 5
output = sd_pipe(**inputs)
image_slice_1 = output.images[0, -3:, -3:, -1]
# forward with same prompt duplicated
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = 5
inputs["prompt_2"] = inputs["prompt"]
output = sd_pipe(**inputs)
image_slice_2 = output.images[0, -3:, -3:, -1]
# ensure the results are equal
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
# forward with different prompt
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = 5
inputs["prompt_2"] = "different prompt"
output = sd_pipe(**inputs)
image_slice_3 = output.images[0, -3:, -3:, -1]
# ensure the results are not equal
assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4
# manually set a negative_prompt
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = 5
inputs["negative_prompt"] = "negative prompt"
output = sd_pipe(**inputs)
image_slice_1 = output.images[0, -3:, -3:, -1]
# forward with same negative_prompt duplicated
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = 5
inputs["negative_prompt"] = "negative prompt"
inputs["negative_prompt_2"] = inputs["negative_prompt"]
output = sd_pipe(**inputs)
image_slice_2 = output.images[0, -3:, -3:, -1]
# ensure the results are equal
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
# forward with different negative_prompt
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = 5
inputs["negative_prompt"] = "negative prompt"
inputs["negative_prompt_2"] = "different negative prompt"
output = sd_pipe(**inputs)
image_slice_3 = output.images[0, -3:, -3:, -1]
# ensure the results are not equal
assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4

View File

@@ -123,7 +123,10 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
image = image.cpu().permute(0, 2, 3, 1)[0]
init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((64, 64))
# create mask
image[8:, 8:, :] = 255
mask_image = Image.fromarray(np.uint8(image)).convert("L").resize((64, 64))
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
@@ -152,7 +155,7 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.4924, 0.4966, 0.4100, 0.5233, 0.5322, 0.4532, 0.5804, 0.5876, 0.4150])
expected_slice = np.array([0.6965, 0.5584, 0.5693, 0.5739, 0.6092, 0.6620, 0.5902, 0.5612, 0.5319])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -367,3 +370,62 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
HeunDiscreteScheduler,
]:
assert_run_mixture(steps, split_1, split_2, scheduler_cls)
def test_stable_diffusion_xl_multi_prompts(self):
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components).to(torch_device)
# forward with single prompt
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = 5
output = sd_pipe(**inputs)
image_slice_1 = output.images[0, -3:, -3:, -1]
# forward with same prompt duplicated
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = 5
inputs["prompt_2"] = inputs["prompt"]
output = sd_pipe(**inputs)
image_slice_2 = output.images[0, -3:, -3:, -1]
# ensure the results are equal
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
# forward with different prompt
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = 5
inputs["prompt_2"] = "different prompt"
output = sd_pipe(**inputs)
image_slice_3 = output.images[0, -3:, -3:, -1]
# ensure the results are not equal
assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4
# manually set a negative_prompt
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = 5
inputs["negative_prompt"] = "negative prompt"
output = sd_pipe(**inputs)
image_slice_1 = output.images[0, -3:, -3:, -1]
# forward with same negative_prompt duplicated
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = 5
inputs["negative_prompt"] = "negative prompt"
inputs["negative_prompt_2"] = inputs["negative_prompt"]
output = sd_pipe(**inputs)
image_slice_2 = output.images[0, -3:, -3:, -1]
# ensure the results are equal
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
# forward with different negative_prompt
inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = 5
inputs["negative_prompt"] = "negative prompt"
inputs["negative_prompt_2"] = "different negative prompt"
output = sd_pipe(**inputs)
image_slice_3 = output.images[0, -3:, -3:, -1]
# ensure the results are not equal
assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4