mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Core] Support negative conditions in SDXL (#4774)
* add: support negative conditions. * fix: key * add: tests * address PR feedback. * add documentation * add img2img support. * add inpainting support. * ad controlnet support * Apply suggestions from code review * modify wording in the doc.
This commit is contained in:
@@ -23,6 +23,7 @@ The abstract of the paper is the following:
|
||||
- Stable Diffusion XL works especially well with images between 768 and 1024.
|
||||
- Stable Diffusion XL can pass a different prompt for each of the text encoders it was trained on as shown below. We can even pass different parts of the same prompt to the text encoders.
|
||||
- Stable Diffusion XL output image can be improved by making use of a refiner as shown below.
|
||||
- One can make use of `negative_original_size`, `negative_crops_coords_top_left`, and `negative_target_size` to influence the generation process.
|
||||
|
||||
### Available checkpoints:
|
||||
|
||||
@@ -74,6 +75,37 @@ prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
|
||||
image = pipe(prompt=prompt).images[0]
|
||||
```
|
||||
|
||||
You can additionally pass negative conditions about an image's size and position to avoid undesirable cropping behavior in the generated image, and improve image resolution. Let's take an example:
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
import torch
|
||||
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
)
|
||||
pipe.to("cuda")
|
||||
|
||||
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
negative_original_size=(512, 512),
|
||||
negative_crops_coords_top_left=(0, 0),
|
||||
negative_target_size=(1024, 1024),
|
||||
).images[0]
|
||||
```
|
||||
|
||||
Here is a comparative example that shows the influence of using three `negative_original_size`s of
|
||||
(128, 128), (256, 256), and (512, 512) respectively:
|
||||
|
||||

|
||||
|
||||
<Tip>
|
||||
|
||||
One can use these negative conditions in the other SDXL pipelines ([Image-To-Image](#image-to-image), [Inpainting](#inpainting), [ControlNet](../controlnet_sdxl.md)) too!
|
||||
|
||||
</Tip>
|
||||
|
||||
### Image-to-image
|
||||
|
||||
You can use SDXL as follows for *image-to-image*:
|
||||
|
||||
@@ -789,6 +789,9 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
|
||||
original_size: Tuple[int, int] = None,
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
target_size: Tuple[int, int] = None,
|
||||
negative_original_size: Optional[Tuple[int, int]] = None,
|
||||
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
negative_target_size: Optional[Tuple[int, int]] = None,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
@@ -895,6 +898,22 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
|
||||
For most cases, `target_size` should be set to the desired height and width of the generated image. If
|
||||
not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
|
||||
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
||||
negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
||||
To negatively condition the generation process based on a specific image resolution. Part of SDXL's
|
||||
micro-conditioning as explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
||||
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
||||
negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
||||
To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
|
||||
micro-conditioning as explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
||||
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
||||
negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
||||
To negatively condition the generation process based on a target image resolution. It should be as same
|
||||
as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
||||
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
@@ -1058,10 +1077,20 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
|
||||
original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
|
||||
)
|
||||
|
||||
if negative_original_size is not None and negative_target_size is not None:
|
||||
negative_add_time_ids = self._get_add_time_ids(
|
||||
negative_original_size,
|
||||
negative_crops_coords_top_left,
|
||||
negative_target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
)
|
||||
else:
|
||||
negative_add_time_ids = add_time_ids
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
||||
add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
|
||||
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(device)
|
||||
add_text_embeds = add_text_embeds.to(device)
|
||||
|
||||
@@ -589,6 +589,9 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
|
||||
original_size: Optional[Tuple[int, int]] = None,
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
target_size: Optional[Tuple[int, int]] = None,
|
||||
negative_original_size: Optional[Tuple[int, int]] = None,
|
||||
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
negative_target_size: Optional[Tuple[int, int]] = None,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
@@ -688,6 +691,21 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
|
||||
For most cases, `target_size` should be set to the desired height and width of the generated image. If
|
||||
not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
|
||||
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
||||
negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
||||
To negatively condition the generation process based on a specific image resolution. Part of SDXL's
|
||||
micro-conditioning as explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
||||
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
||||
negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
||||
To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
|
||||
micro-conditioning as explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
||||
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
||||
negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
||||
To negatively condition the generation process based on a target image resolution. It should be as same
|
||||
as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
||||
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -783,11 +801,20 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
|
||||
add_time_ids = self._get_add_time_ids(
|
||||
original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
|
||||
)
|
||||
if negative_original_size is not None and negative_target_size is not None:
|
||||
negative_add_time_ids = self._get_add_time_ids(
|
||||
negative_original_size,
|
||||
negative_crops_coords_top_left,
|
||||
negative_target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
)
|
||||
else:
|
||||
negative_add_time_ids = add_time_ids
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
||||
add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
|
||||
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(device)
|
||||
add_text_embeds = add_text_embeds.to(device)
|
||||
|
||||
@@ -601,14 +601,25 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
|
||||
return latents
|
||||
|
||||
def _get_add_time_ids(
|
||||
self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype
|
||||
self,
|
||||
original_size,
|
||||
crops_coords_top_left,
|
||||
target_size,
|
||||
aesthetic_score,
|
||||
negative_aesthetic_score,
|
||||
negative_original_size,
|
||||
negative_crops_coords_top_left,
|
||||
negative_target_size,
|
||||
dtype,
|
||||
):
|
||||
if self.config.requires_aesthetics_score:
|
||||
add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
|
||||
add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,))
|
||||
add_neg_time_ids = list(
|
||||
negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)
|
||||
)
|
||||
else:
|
||||
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
||||
add_neg_time_ids = list(original_size + crops_coords_top_left + target_size)
|
||||
add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)
|
||||
|
||||
passed_add_embed_dim = (
|
||||
self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
|
||||
@@ -690,6 +701,9 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
|
||||
original_size: Tuple[int, int] = None,
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
target_size: Tuple[int, int] = None,
|
||||
negative_original_size: Optional[Tuple[int, int]] = None,
|
||||
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
negative_target_size: Optional[Tuple[int, int]] = None,
|
||||
aesthetic_score: float = 6.0,
|
||||
negative_aesthetic_score: float = 2.5,
|
||||
):
|
||||
@@ -804,6 +818,21 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
|
||||
For most cases, `target_size` should be set to the desired height and width of the generated image. If
|
||||
not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
|
||||
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
||||
negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
||||
To negatively condition the generation process based on a specific image resolution. Part of SDXL's
|
||||
micro-conditioning as explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
||||
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
||||
negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
||||
To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
|
||||
micro-conditioning as explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
||||
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
||||
negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
||||
To negatively condition the generation process based on a target image resolution. It should be as same
|
||||
as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
||||
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
||||
aesthetic_score (`float`, *optional*, defaults to 6.0):
|
||||
Used to simulate an aesthetic score of the generated image by influencing the positive text condition.
|
||||
Part of SDXL's micro-conditioning as explained in section 2.2 of
|
||||
@@ -908,6 +937,11 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
|
||||
target_size = target_size or (height, width)
|
||||
|
||||
# 8. Prepare added time ids & embeddings
|
||||
if negative_original_size is None:
|
||||
negative_original_size = original_size
|
||||
if negative_target_size is None:
|
||||
negative_target_size = target_size
|
||||
|
||||
add_text_embeds = pooled_prompt_embeds
|
||||
add_time_ids, add_neg_time_ids = self._get_add_time_ids(
|
||||
original_size,
|
||||
@@ -915,6 +949,9 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
|
||||
target_size,
|
||||
aesthetic_score,
|
||||
negative_aesthetic_score,
|
||||
negative_original_size,
|
||||
negative_crops_coords_top_left,
|
||||
negative_target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
)
|
||||
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
|
||||
|
||||
@@ -813,14 +813,25 @@ class StableDiffusionXLInpaintPipeline(DiffusionPipeline, LoraLoaderMixin, FromS
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
|
||||
def _get_add_time_ids(
|
||||
self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype
|
||||
self,
|
||||
original_size,
|
||||
crops_coords_top_left,
|
||||
target_size,
|
||||
aesthetic_score,
|
||||
negative_aesthetic_score,
|
||||
negative_original_size,
|
||||
negative_crops_coords_top_left,
|
||||
negative_target_size,
|
||||
dtype,
|
||||
):
|
||||
if self.config.requires_aesthetics_score:
|
||||
add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
|
||||
add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,))
|
||||
add_neg_time_ids = list(
|
||||
negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)
|
||||
)
|
||||
else:
|
||||
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
||||
add_neg_time_ids = list(original_size + crops_coords_top_left + target_size)
|
||||
add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)
|
||||
|
||||
passed_add_embed_dim = (
|
||||
self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
|
||||
@@ -905,6 +916,9 @@ class StableDiffusionXLInpaintPipeline(DiffusionPipeline, LoraLoaderMixin, FromS
|
||||
original_size: Tuple[int, int] = None,
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
target_size: Tuple[int, int] = None,
|
||||
negative_original_size: Optional[Tuple[int, int]] = None,
|
||||
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
negative_target_size: Optional[Tuple[int, int]] = None,
|
||||
aesthetic_score: float = 6.0,
|
||||
negative_aesthetic_score: float = 2.5,
|
||||
):
|
||||
@@ -1025,6 +1039,21 @@ class StableDiffusionXLInpaintPipeline(DiffusionPipeline, LoraLoaderMixin, FromS
|
||||
For most cases, `target_size` should be set to the desired height and width of the generated image. If
|
||||
not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
|
||||
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
||||
negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
||||
To negatively condition the generation process based on a specific image resolution. Part of SDXL's
|
||||
micro-conditioning as explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
||||
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
||||
negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
||||
To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
|
||||
micro-conditioning as explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
||||
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
||||
negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
||||
To negatively condition the generation process based on a target image resolution. It should be as same
|
||||
as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
||||
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
||||
aesthetic_score (`float`, *optional*, defaults to 6.0):
|
||||
Used to simulate an aesthetic score of the generated image by influencing the positive text condition.
|
||||
Part of SDXL's micro-conditioning as explained in section 2.2 of
|
||||
@@ -1199,6 +1228,11 @@ class StableDiffusionXLInpaintPipeline(DiffusionPipeline, LoraLoaderMixin, FromS
|
||||
target_size = target_size or (height, width)
|
||||
|
||||
# 10. Prepare added time ids & embeddings
|
||||
if negative_original_size is None:
|
||||
negative_original_size = original_size
|
||||
if negative_target_size is None:
|
||||
negative_target_size = target_size
|
||||
|
||||
add_text_embeds = pooled_prompt_embeds
|
||||
add_time_ids, add_neg_time_ids = self._get_add_time_ids(
|
||||
original_size,
|
||||
@@ -1206,6 +1240,9 @@ class StableDiffusionXLInpaintPipeline(DiffusionPipeline, LoraLoaderMixin, FromS
|
||||
target_size,
|
||||
aesthetic_score,
|
||||
negative_aesthetic_score,
|
||||
negative_original_size,
|
||||
negative_crops_coords_top_left,
|
||||
negative_target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
)
|
||||
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
|
||||
|
||||
@@ -160,7 +160,7 @@ class StableDiffusionXLControlNetPipelineFastTests(
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
"output_type": "numpy",
|
||||
"output_type": "np",
|
||||
"image": image,
|
||||
}
|
||||
|
||||
@@ -680,6 +680,25 @@ class StableDiffusionXLMultiControlNetOneModelPipelineFastTests(
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
|
||||
|
||||
def test_negative_conditions(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(torch_device)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
image = pipe(**inputs).images
|
||||
image_slice_without_neg_cond = image[0, -3:, -3:, -1]
|
||||
|
||||
image = pipe(
|
||||
**inputs,
|
||||
negative_original_size=(512, 512),
|
||||
negative_crops_coords_top_left=(0, 0),
|
||||
negative_target_size=(1024, 1024),
|
||||
).images
|
||||
image_slice_with_neg_cond = image[0, -3:, -3:, -1]
|
||||
|
||||
self.assertTrue(np.abs(image_slice_without_neg_cond - image_slice_with_neg_cond).max() > 1e-2)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
@@ -129,7 +129,7 @@ class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTest
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 5.0,
|
||||
"output_type": "numpy",
|
||||
"output_type": "np",
|
||||
}
|
||||
return inputs
|
||||
|
||||
@@ -691,6 +691,27 @@ class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTest
|
||||
# ensure the results are not equal
|
||||
assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4
|
||||
|
||||
def test_stable_diffusion_xl_negative_conditions(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = StableDiffusionXLPipeline(**components)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = sd_pipe(**inputs).images
|
||||
image_slice_with_no_neg_cond = image[0, -3:, -3:, -1]
|
||||
|
||||
image = sd_pipe(
|
||||
**inputs,
|
||||
negative_original_size=(512, 512),
|
||||
negative_crops_coords_top_left=(0, 0),
|
||||
negative_target_size=(1024, 1024),
|
||||
).images
|
||||
image_slice_with_neg_cond = image[0, -3:, -3:, -1]
|
||||
|
||||
self.assertTrue(np.abs(image_slice_with_no_neg_cond - image_slice_with_neg_cond).max() > 1e-2)
|
||||
|
||||
def test_stable_diffusion_xl_save_from_pretrained(self):
|
||||
pipes = []
|
||||
components = self.get_dummy_components()
|
||||
|
||||
@@ -138,7 +138,7 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 5.0,
|
||||
"output_type": "numpy",
|
||||
"output_type": "np",
|
||||
"strength": 0.8,
|
||||
}
|
||||
return inputs
|
||||
@@ -315,3 +315,31 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
|
||||
|
||||
# ensure the results are not equal
|
||||
assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4
|
||||
|
||||
def test_stable_diffusion_xl_img2img_negative_conditions(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
|
||||
sd_pipe = self.pipeline_class(**components)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = sd_pipe(**inputs).images
|
||||
image_slice_with_no_neg_conditions = image[0, -3:, -3:, -1]
|
||||
|
||||
image = sd_pipe(
|
||||
**inputs,
|
||||
negative_original_size=(512, 512),
|
||||
negative_crops_coords_top_left=(
|
||||
0,
|
||||
0,
|
||||
),
|
||||
negative_target_size=(1024, 1024),
|
||||
).images
|
||||
image_slice_with_neg_conditions = image[0, -3:, -3:, -1]
|
||||
|
||||
assert (
|
||||
np.abs(image_slice_with_no_neg_conditions.flatten() - image_slice_with_neg_conditions.flatten()).max()
|
||||
> 1e-4
|
||||
)
|
||||
|
||||
@@ -139,7 +139,7 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
"output_type": "numpy",
|
||||
"output_type": "np",
|
||||
}
|
||||
return inputs
|
||||
|
||||
@@ -471,3 +471,30 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
|
||||
|
||||
# ensure the results are not equal
|
||||
assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4
|
||||
|
||||
def test_stable_diffusion_xl_img2img_negative_conditions(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = self.pipeline_class(**components)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = sd_pipe(**inputs).images
|
||||
image_slice_with_no_neg_conditions = image[0, -3:, -3:, -1]
|
||||
|
||||
image = sd_pipe(
|
||||
**inputs,
|
||||
negative_original_size=(512, 512),
|
||||
negative_crops_coords_top_left=(
|
||||
0,
|
||||
0,
|
||||
),
|
||||
negative_target_size=(1024, 1024),
|
||||
).images
|
||||
image_slice_with_neg_conditions = image[0, -3:, -3:, -1]
|
||||
|
||||
assert (
|
||||
np.abs(image_slice_with_no_neg_conditions.flatten() - image_slice_with_neg_conditions.flatten()).max()
|
||||
> 1e-4
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user