1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

fix bug in StableDiffusionXLControlNetPipeline when using guess_mode (#4799)

* fix



---------

Co-authored-by: yiyixuxu <yixu310@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
YiYi Xu
2023-08-28 06:51:17 -10:00
committed by GitHub
parent e3f3672f46
commit 934d439a42
2 changed files with 30 additions and 2 deletions

View File

@@ -1104,15 +1104,22 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
# controlnet(s) inference
if guess_mode and do_classifier_free_guidance:
# Infer ControlNet only for the conditional batch.
control_model_input = latents
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
controlnet_added_cond_kwargs = {
"text_embeds": add_text_embeds.chunk(2)[1],
"time_ids": add_time_ids.chunk(2)[1],
}
else:
control_model_input = latent_model_input
controlnet_prompt_embeds = prompt_embeds
controlnet_added_cond_kwargs = added_cond_kwargs
if isinstance(controlnet_keep[i], list):
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
@@ -1122,7 +1129,6 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
controlnet_cond_scale = controlnet_cond_scale[0]
cond_scale = controlnet_cond_scale * controlnet_keep[i]
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
down_block_res_samples, mid_block_res_sample = self.controlnet(
control_model_input,
t,
@@ -1130,7 +1136,7 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
controlnet_cond=image,
conditioning_scale=cond_scale,
guess_mode=guess_mode,
added_cond_kwargs=added_cond_kwargs,
added_cond_kwargs=controlnet_added_cond_kwargs,
return_dict=False,
)

View File

@@ -300,6 +300,28 @@ class StableDiffusionXLControlNetPipelineFastTests(
# make sure that it's equal
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
def test_controlnet_sdxl_guess(self):
    # Runs the dummy pipeline on CPU with ``guess_mode`` switched on and
    # compares a small corner of the produced image with stored reference values.
    target_device = "cpu"

    dummy_parts = self.get_dummy_components()
    pipe = self.pipeline_class(**dummy_parts)
    pipe = pipe.to(target_device)
    pipe.set_progress_bar_config(disable=None)

    call_kwargs = self.get_dummy_inputs(target_device)
    call_kwargs["guess_mode"] = True
    result = pipe(**call_kwargs)

    # Trailing 3x3 window of the first image, taken at the last index of the
    # final axis (presumably the channel axis of a (batch, H, W, C) array -- confirm).
    observed = result.images[0, -3:, -3:, -1].flatten()
    reference = np.array(
        [0.7330834, 0.590667, 0.5667336, 0.6029023, 0.5679491, 0.5968194, 0.4032986, 0.47612396, 0.5089609]
    )

    # make sure that it's equal
    largest_deviation = np.abs(observed - reference).max()
    assert largest_deviation < 1e-4
class StableDiffusionXLMultiControlNetPipelineFastTests(
PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase