mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[SDXL Refiner] Fix refiner forward pass for batched input (#4327)
* fix_batch_xl * Fix other pipelines as well * up * up * Update tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py * sort * up * Finish it all up Co-authored-by: Bagheera <bghira@users.github.com> * Co-authored-by: Bagheera bghira@users.github.com * Co-authored-by: Bagheera <bghira@users.github.com> * Finish it all up Co-authored-by: Bagheera <bghira@users.github.com>
This commit is contained in:
committed by
Sayak Paul
parent
aa4634a7fa
commit
c63d7cdba0
@@ -906,15 +906,17 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
|
||||
negative_aesthetic_score,
|
||||
dtype=prompt_embeds.dtype,
|
||||
)
|
||||
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
||||
add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
|
||||
add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(device)
|
||||
add_text_embeds = add_text_embeds.to(device)
|
||||
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
||||
add_time_ids = add_time_ids.to(device)
|
||||
|
||||
# 9. Denoising loop
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
@@ -1168,15 +1168,17 @@ class StableDiffusionXLInpaintPipeline(
|
||||
negative_aesthetic_score,
|
||||
dtype=prompt_embeds.dtype,
|
||||
)
|
||||
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
||||
add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
|
||||
add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(device)
|
||||
add_text_embeds = add_text_embeds.to(device)
|
||||
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
||||
add_time_ids = add_time_ids.to(device)
|
||||
|
||||
# 11. Denoising loop
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
@@ -811,6 +811,7 @@ class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFile
|
||||
negative_aesthetic_score,
|
||||
dtype=prompt_embeds.dtype,
|
||||
)
|
||||
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
|
||||
|
||||
original_prompt_embeds_len = len(prompt_embeds)
|
||||
original_add_text_embeds_len = len(add_text_embeds)
|
||||
@@ -819,6 +820,7 @@ class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFile
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds], dim=0)
|
||||
add_text_embeds = torch.cat([add_text_embeds, negative_pooled_prompt_embeds], dim=0)
|
||||
add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
|
||||
add_time_ids = torch.cat([add_time_ids, add_neg_time_ids], dim=0)
|
||||
|
||||
# Make dimensions consistent
|
||||
@@ -828,7 +830,7 @@ class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFile
|
||||
|
||||
prompt_embeds = prompt_embeds.to(device).to(torch.float32)
|
||||
add_text_embeds = add_text_embeds.to(device).to(torch.float32)
|
||||
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
||||
add_time_ids = add_time_ids.to(device)
|
||||
|
||||
# 11. Denoising loop
|
||||
self.unet = self.unet.to(torch.float32)
|
||||
|
||||
@@ -64,7 +64,7 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
|
||||
addition_embed_type="text_time",
|
||||
addition_time_embed_dim=8,
|
||||
transformer_layers_per_block=(1, 2),
|
||||
projection_class_embeddings_input_dim=80, # 6 * 8 + 32
|
||||
projection_class_embeddings_input_dim=72, # 5 * 8 + 32
|
||||
cross_attention_dim=64 if not skip_first_text_encoder else 32,
|
||||
)
|
||||
scheduler = EulerDiscreteScheduler(
|
||||
@@ -113,9 +113,18 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
|
||||
"tokenizer": tokenizer if not skip_first_text_encoder else None,
|
||||
"text_encoder_2": text_encoder_2,
|
||||
"tokenizer_2": tokenizer_2,
|
||||
"requires_aesthetics_score": True,
|
||||
}
|
||||
return components
|
||||
|
||||
def test_components_function(self):
|
||||
init_components = self.get_dummy_components()
|
||||
init_components.pop("requires_aesthetics_score")
|
||||
pipe = self.pipeline_class(**init_components)
|
||||
|
||||
self.assertTrue(hasattr(pipe, "components"))
|
||||
self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
|
||||
image = image / 2 + 0.5
|
||||
@@ -147,7 +156,7 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
|
||||
|
||||
assert image.shape == (1, 32, 32, 3)
|
||||
|
||||
expected_slice = np.array([0.4656, 0.4840, 0.4439, 0.6698, 0.5574, 0.4524, 0.5799, 0.5943, 0.5165])
|
||||
expected_slice = np.array([0.4664, 0.4886, 0.4403, 0.6902, 0.5592, 0.4534, 0.5931, 0.5951, 0.5224])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@@ -165,7 +174,7 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
|
||||
|
||||
assert image.shape == (1, 32, 32, 3)
|
||||
|
||||
expected_slice = np.array([0.4676, 0.4865, 0.4335, 0.6715, 0.5578, 0.4497, 0.5847, 0.5967, 0.5198])
|
||||
expected_slice = np.array([0.4578, 0.4981, 0.4301, 0.6454, 0.5588, 0.4442, 0.5678, 0.5940, 0.5176])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
|
||||
addition_embed_type="text_time",
|
||||
addition_time_embed_dim=8,
|
||||
transformer_layers_per_block=(1, 2),
|
||||
projection_class_embeddings_input_dim=80, # 6 * 8 + 32
|
||||
projection_class_embeddings_input_dim=72, # 5 * 8 + 32
|
||||
cross_attention_dim=64 if not skip_first_text_encoder else 32,
|
||||
)
|
||||
scheduler = EulerDiscreteScheduler(
|
||||
@@ -115,6 +115,7 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
|
||||
"tokenizer": tokenizer if not skip_first_text_encoder else None,
|
||||
"text_encoder_2": text_encoder_2,
|
||||
"tokenizer_2": tokenizer_2,
|
||||
"requires_aesthetics_score": True,
|
||||
}
|
||||
return components
|
||||
|
||||
@@ -142,6 +143,14 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_components_function(self):
|
||||
init_components = self.get_dummy_components()
|
||||
init_components.pop("requires_aesthetics_score")
|
||||
pipe = self.pipeline_class(**init_components)
|
||||
|
||||
self.assertTrue(hasattr(pipe, "components"))
|
||||
self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
|
||||
|
||||
def test_stable_diffusion_xl_inpaint_euler(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
@@ -155,7 +164,7 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
|
||||
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
|
||||
expected_slice = np.array([0.6965, 0.5584, 0.5693, 0.5739, 0.6092, 0.6620, 0.5902, 0.5612, 0.5319])
|
||||
expected_slice = np.array([0.8029, 0.5523, 0.5825, 0.6003, 0.6702, 0.7018, 0.6369, 0.5955, 0.5123])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@@ -250,10 +259,9 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
|
||||
image = sd_pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
print(torch.from_numpy(image_slice).flatten())
|
||||
assert image.shape == (1, 64, 64, 3)
|
||||
|
||||
expected_slice = np.array([0.9106, 0.6563, 0.6766, 0.6537, 0.6709, 0.7367, 0.6537, 0.5937, 0.5418])
|
||||
expected_slice = np.array([0.7045, 0.4838, 0.5454, 0.6270, 0.6168, 0.6717, 0.6484, 0.5681, 0.4922])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ class StableDiffusionXLInstructPix2PixPipelineFastTests(
|
||||
addition_embed_type="text_time",
|
||||
addition_time_embed_dim=8,
|
||||
transformer_layers_per_block=(1, 2),
|
||||
projection_class_embeddings_input_dim=80, # 6 * 8 + 32
|
||||
projection_class_embeddings_input_dim=72, # 5 * 8 + 32
|
||||
cross_attention_dim=64,
|
||||
)
|
||||
|
||||
@@ -118,8 +118,7 @@ class StableDiffusionXLInstructPix2PixPipelineFastTests(
|
||||
"tokenizer": tokenizer,
|
||||
"text_encoder_2": text_encoder_2,
|
||||
"tokenizer_2": tokenizer_2,
|
||||
# "safety_checker": None,
|
||||
# "feature_extractor": None,
|
||||
"requires_aesthetics_score": True,
|
||||
}
|
||||
return components
|
||||
|
||||
@@ -141,6 +140,14 @@ class StableDiffusionXLInstructPix2PixPipelineFastTests(
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_components_function(self):
|
||||
init_components = self.get_dummy_components()
|
||||
init_components.pop("requires_aesthetics_score")
|
||||
pipe = self.pipeline_class(**init_components)
|
||||
|
||||
self.assertTrue(hasattr(pipe, "components"))
|
||||
self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user