Add SDXL refiner only tests (#5041)
* add refiner only tests
* make style
@@ -26,7 +26,12 @@ from diffusers import (
     StableDiffusionXLImg2ImgPipeline,
     UNet2DConditionModel,
 )
-from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, require_torch_gpu, torch_device
+from diffusers.utils.testing_utils import (
+    enable_full_determinism,
+    floats_tensor,
+    require_torch_gpu,
+    torch_device,
+)
 
 from ..pipeline_params import (
     IMAGE_TO_IMAGE_IMAGE_PARAMS,
@@ -159,24 +164,6 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
 
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
 
-    def test_stable_diffusion_xl_refiner(self):
-        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
-        components = self.get_dummy_components(skip_first_text_encoder=True)
-
-        sd_pipe = StableDiffusionXLImg2ImgPipeline(**components)
-        sd_pipe = sd_pipe.to(device)
-        sd_pipe.set_progress_bar_config(disable=None)
-
-        inputs = self.get_dummy_inputs(device)
-        image = sd_pipe(**inputs).images
-        image_slice = image[0, -3:, -3:, -1]
-
-        assert image.shape == (1, 32, 32, 3)
-
-        expected_slice = np.array([0.4578, 0.4981, 0.4301, 0.6454, 0.5588, 0.4442, 0.5678, 0.5940, 0.5176])
-
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
     def test_attention_slicing_forward_pass(self):
         super().test_attention_slicing_forward_pass(expected_max_diff=3e-3)
 
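The test deleted above is not lost: it is re-homed in the dedicated StableDiffusionXLImg2ImgRefinerOnlyPipelineFastTests class added at the end of this diff, where the dummy components drop the first text encoder outright instead of relying on a skip_first_text_encoder flag. That setup mirrors the published refiner checkpoint, which ships without the first CLIP text encoder. A minimal loading sketch, assuming the stock Hub checkpoint (not part of this diff):

    import torch
    from diffusers import StableDiffusionXLImg2ImgPipeline

    # The refiner checkpoint has no first CLIP text encoder, matching the
    # `text_encoder=None` / `tokenizer=None` dummy components added below.
    pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16
    )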
@@ -195,7 +182,8 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
         sd_pipe.set_progress_bar_config(disable=None)
 
         # forward without prompt embeds
-        inputs = self.get_dummy_inputs(torch_device)
+        generator_device = "cpu"
+        inputs = self.get_dummy_inputs(generator_device)
         negative_prompt = 3 * ["this is a negative prompt"]
         inputs["negative_prompt"] = negative_prompt
         inputs["prompt"] = 3 * [inputs["prompt"]]
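This hunk, and each one below with the same shape, swaps get_dummy_inputs(torch_device) for get_dummy_inputs("cpu"): seeding the torch.Generator on the CPU makes the random latents identical regardless of the device the pipeline itself runs on. A minimal sketch of the pattern (helper name is illustrative):

    import torch

    def make_generator(seed: int) -> torch.Generator:
        # A CPU generator yields the same pseudo-random sequence on every
        # platform, so expected image slices stay stable across CPU/CUDA runs.
        return torch.Generator(device="cpu").manual_seed(seed)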
@@ -204,7 +192,8 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
         image_slice_1 = output.images[0, -3:, -3:, -1]
 
         # forward with prompt embeds
-        inputs = self.get_dummy_inputs(torch_device)
+        generator_device = "cpu"
+        inputs = self.get_dummy_inputs(generator_device)
         negative_prompt = 3 * ["this is a negative prompt"]
         prompt = 3 * [inputs.pop("prompt")]
 
@@ -248,7 +237,8 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
         for pipe in pipes:
             pipe.unet.set_default_attn_processor()
 
-            inputs = self.get_dummy_inputs(torch_device)
+            generator_device = "cpu"
+            inputs = self.get_dummy_inputs(generator_device)
             image = pipe(**inputs).images
 
             image_slices.append(image[0, -3:, -3:, -1].flatten())
@@ -261,13 +251,15 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
         sd_pipe = self.pipeline_class(**components).to(torch_device)
 
         # forward with single prompt
-        inputs = self.get_dummy_inputs(torch_device)
+        generator_device = "cpu"
+        inputs = self.get_dummy_inputs(generator_device)
         inputs["num_inference_steps"] = 5
         output = sd_pipe(**inputs)
         image_slice_1 = output.images[0, -3:, -3:, -1]
 
         # forward with same prompt duplicated
-        inputs = self.get_dummy_inputs(torch_device)
+        generator_device = "cpu"
+        inputs = self.get_dummy_inputs(generator_device)
         inputs["num_inference_steps"] = 5
         inputs["prompt_2"] = inputs["prompt"]
         output = sd_pipe(**inputs)
@@ -277,7 +269,8 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
         assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
 
         # forward with different prompt
-        inputs = self.get_dummy_inputs(torch_device)
+        generator_device = "cpu"
+        inputs = self.get_dummy_inputs(generator_device)
         inputs["num_inference_steps"] = 5
         inputs["prompt_2"] = "different prompt"
         output = sd_pipe(**inputs)
@@ -287,14 +280,16 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
         assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4
 
         # manually set a negative_prompt
-        inputs = self.get_dummy_inputs(torch_device)
+        generator_device = "cpu"
+        inputs = self.get_dummy_inputs(generator_device)
         inputs["num_inference_steps"] = 5
         inputs["negative_prompt"] = "negative prompt"
         output = sd_pipe(**inputs)
         image_slice_1 = output.images[0, -3:, -3:, -1]
 
         # forward with same negative_prompt duplicated
-        inputs = self.get_dummy_inputs(torch_device)
+        generator_device = "cpu"
+        inputs = self.get_dummy_inputs(generator_device)
         inputs["num_inference_steps"] = 5
         inputs["negative_prompt"] = "negative prompt"
         inputs["negative_prompt_2"] = inputs["negative_prompt"]
@@ -305,7 +300,8 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
         assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
 
         # forward with different negative_prompt
-        inputs = self.get_dummy_inputs(torch_device)
+        generator_device = "cpu"
+        inputs = self.get_dummy_inputs(generator_device)
         inputs["num_inference_steps"] = 5
         inputs["negative_prompt"] = "negative prompt"
         inputs["negative_prompt_2"] = "different negative prompt"
@@ -342,3 +338,229 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
             np.abs(image_slice_with_no_neg_conditions.flatten() - image_slice_with_neg_conditions.flatten()).max()
             > 1e-4
         )
+
+
+class StableDiffusionXLImg2ImgRefinerOnlyPipelineFastTests(
+    PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase
+):
+    pipeline_class = StableDiffusionXLImg2ImgPipeline
+    params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
+    required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
+    batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
+    image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
+    image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
+
+    def get_dummy_components(self):
+        torch.manual_seed(0)
+        unet = UNet2DConditionModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            sample_size=32,
+            in_channels=4,
+            out_channels=4,
+            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+            # SD2-specific config below
+            attention_head_dim=(2, 4),
+            use_linear_projection=True,
+            addition_embed_type="text_time",
+            addition_time_embed_dim=8,
+            transformer_layers_per_block=(1, 2),
+            projection_class_embeddings_input_dim=72,  # 5 * 8 + 32
+            cross_attention_dim=32,
+        )
+        scheduler = EulerDiscreteScheduler(
+            beta_start=0.00085,
+            beta_end=0.012,
+            steps_offset=1,
+            beta_schedule="scaled_linear",
+            timestep_spacing="leading",
+        )
+        torch.manual_seed(0)
+        vae = AutoencoderKL(
+            block_out_channels=[32, 64],
+            in_channels=3,
+            out_channels=3,
+            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+            latent_channels=4,
+            sample_size=128,
+        )
+        torch.manual_seed(0)
+        text_encoder_config = CLIPTextConfig(
+            bos_token_id=0,
+            eos_token_id=2,
+            hidden_size=32,
+            intermediate_size=37,
+            layer_norm_eps=1e-05,
+            num_attention_heads=4,
+            num_hidden_layers=5,
+            pad_token_id=1,
+            vocab_size=1000,
+            # SD2-specific config below
+            hidden_act="gelu",
+            projection_dim=32,
+        )
+        text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
+        tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        components = {
+            "unet": unet,
+            "scheduler": scheduler,
+            "vae": vae,
+            "tokenizer": None,
+            "text_encoder": None,
+            "text_encoder_2": text_encoder_2,
+            "tokenizer_2": tokenizer_2,
+            "requires_aesthetics_score": True,
+        }
+        return components
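The projection_class_embeddings_input_dim=72 above follows from the refiner's micro-conditioning: with requires_aesthetics_score=True, the add_time_ids vector carries five scalars (original_size, crops_coords_top_left, aesthetic_score), each embedded to addition_time_embed_dim=8 and concatenated with the pooled text embedding of width projection_dim=32. The arithmetic, spelled out (variable names are illustrative):

    addition_time_embed_dim = 8
    pooled_text_embed_dim = 32   # text_encoder_2's projection_dim
    num_micro_conditions = 5     # original_size (2) + crops_coords_top_left (2) + aesthetic_score (1)

    assert num_micro_conditions * addition_time_embed_dim + pooled_text_embed_dim == 72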
+
+    def test_components_function(self):
+        init_components = self.get_dummy_components()
+        init_components.pop("requires_aesthetics_score")
+        pipe = self.pipeline_class(**init_components)
+
+        self.assertTrue(hasattr(pipe, "components"))
+        self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
+
+    def get_dummy_inputs(self, device, seed=0):
+        image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
+        image = image / 2 + 0.5
+        if str(device).startswith("mps"):
+            generator = torch.manual_seed(seed)
+        else:
+            generator = torch.Generator(device=device).manual_seed(seed)
+        inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "image": image,
+            "generator": generator,
+            "num_inference_steps": 2,
+            "guidance_scale": 5.0,
+            "output_type": "np",
+            "strength": 0.8,
+        }
+        return inputs
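Two details of get_dummy_inputs are worth flagging. The mps branch falls back to torch.manual_seed because seeded torch.Generator objects were not supported on that backend. And with num_inference_steps=2 and strength=0.8, only a single denoising step actually runs: img2img strength skips the front of the timestep schedule. A sketch of the standard diffusers img2img convention (helper name is illustrative):

    def effective_steps(num_inference_steps: int, strength: float) -> int:
        # strength decides how many of the scheduled steps are retained; the
        # rest are skipped so the input image is only partially re-noised.
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
        t_start = max(num_inference_steps - init_timestep, 0)
        return num_inference_steps - t_start

    assert effective_steps(2, 0.8) == 1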
+
+    def test_stable_diffusion_xl_img2img_euler(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLImg2ImgPipeline(**components)
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        image = sd_pipe(**inputs).images
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 32, 32, 3)
+
+        expected_slice = np.array([0.4745, 0.4924, 0.4338, 0.6468, 0.5547, 0.4419, 0.5646, 0.5897, 0.5146])
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+    @require_torch_gpu
+    def test_stable_diffusion_xl_offloads(self):
+        pipes = []
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLImg2ImgPipeline(**components).to(torch_device)
+        pipes.append(sd_pipe)
+
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLImg2ImgPipeline(**components)
+        sd_pipe.enable_model_cpu_offload()
+        pipes.append(sd_pipe)
+
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLImg2ImgPipeline(**components)
+        sd_pipe.enable_sequential_cpu_offload()
+        pipes.append(sd_pipe)
+
+        image_slices = []
+        for pipe in pipes:
+            pipe.unet.set_default_attn_processor()
+
+            generator_device = "cpu"
+            inputs = self.get_dummy_inputs(generator_device)
+            image = pipe(**inputs).images
+
+            image_slices.append(image[0, -3:, -3:, -1].flatten())
+
+        assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
+        assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3
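The offload test pins down that both offloading modes are numerically interchangeable with a fully GPU-resident pipeline to within 1e-3: enable_model_cpu_offload keeps each whole sub-model on the GPU only while it runs, while enable_sequential_cpu_offload streams weights at submodule granularity, trading more memory savings for latency. set_default_attn_processor is applied to every pipeline first so all three use the same attention implementation.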
+
+    def test_stable_diffusion_xl_img2img_negative_conditions(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components()
+
+        sd_pipe = self.pipeline_class(**components)
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        image = sd_pipe(**inputs).images
+        image_slice_with_no_neg_conditions = image[0, -3:, -3:, -1]
+
+        image = sd_pipe(
+            **inputs,
+            negative_original_size=(512, 512),
+            negative_crops_coords_top_left=(
+                0,
+                0,
+            ),
+            negative_target_size=(1024, 1024),
+        ).images
+        image_slice_with_neg_conditions = image[0, -3:, -3:, -1]
+
+        assert (
+            np.abs(image_slice_with_no_neg_conditions.flatten() - image_slice_with_neg_conditions.flatten()).max()
+            > 1e-4
+        )
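This test exercises SDXL's negative micro-conditioning: negative_original_size, negative_crops_coords_top_left, and negative_target_size swap in a separate add_time_ids vector for the unconditional branch of classifier-free guidance, so the assertion only requires the output to move (> 1e-4) relative to the unconditioned run rather than match a fixed slice.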
+
+    def test_stable_diffusion_xl_img2img_negative_prompt_embeds(self):
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLImg2ImgPipeline(**components)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        # forward without prompt embeds
+        generator_device = "cpu"
+        inputs = self.get_dummy_inputs(generator_device)
+        negative_prompt = 3 * ["this is a negative prompt"]
+        inputs["negative_prompt"] = negative_prompt
+        inputs["prompt"] = 3 * [inputs["prompt"]]
+
+        output = sd_pipe(**inputs)
+        image_slice_1 = output.images[0, -3:, -3:, -1]
+
+        # forward with prompt embeds
+        generator_device = "cpu"
+        inputs = self.get_dummy_inputs(generator_device)
+        negative_prompt = 3 * ["this is a negative prompt"]
+        prompt = 3 * [inputs.pop("prompt")]
+
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        ) = sd_pipe.encode_prompt(prompt, negative_prompt=negative_prompt)
+
+        output = sd_pipe(
+            **inputs,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+        )
+        image_slice_2 = output.images[0, -3:, -3:, -1]
+
+        # make sure that it's equal
+        assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
+
+    def test_attention_slicing_forward_pass(self):
+        super().test_attention_slicing_forward_pass(expected_max_diff=3e-3)
+
+    def test_inference_batch_single_identical(self):
+        super().test_inference_batch_single_identical(expected_max_diff=3e-3)
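The negative_prompt_embeds test fixes the contract that pre-computed embeddings are a drop-in replacement for string prompts: the (prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds) tuple returned by encode_prompt, fed back through the pipeline, must reproduce the string-prompt output to within 1e-4.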