mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
fix bug in StableDiffusionXLControlNetPipeline when use guess_mode (#4799)
* fix --------- Co-authored-by: yiyixuxu <yixu310@gmail,com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -1104,15 +1104,22 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
||||
|
||||
# controlnet(s) inference
|
||||
if guess_mode and do_classifier_free_guidance:
|
||||
# Infer ControlNet only for the conditional batch.
|
||||
control_model_input = latents
|
||||
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
|
||||
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
|
||||
controlnet_added_cond_kwargs = {
|
||||
"text_embeds": add_text_embeds.chunk(2)[1],
|
||||
"time_ids": add_time_ids.chunk(2)[1],
|
||||
}
|
||||
else:
|
||||
control_model_input = latent_model_input
|
||||
controlnet_prompt_embeds = prompt_embeds
|
||||
controlnet_added_cond_kwargs = added_cond_kwargs
|
||||
|
||||
if isinstance(controlnet_keep[i], list):
|
||||
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
|
||||
@@ -1122,7 +1129,6 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
|
||||
controlnet_cond_scale = controlnet_cond_scale[0]
|
||||
cond_scale = controlnet_cond_scale * controlnet_keep[i]
|
||||
|
||||
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
||||
down_block_res_samples, mid_block_res_sample = self.controlnet(
|
||||
control_model_input,
|
||||
t,
|
||||
@@ -1130,7 +1136,7 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
|
||||
controlnet_cond=image,
|
||||
conditioning_scale=cond_scale,
|
||||
guess_mode=guess_mode,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
added_cond_kwargs=controlnet_added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -300,6 +300,28 @@ class StableDiffusionXLControlNetPipelineFastTests(
|
||||
# make sure that it's equal
|
||||
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
|
||||
|
||||
def test_controlnet_sdxl_guess(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
|
||||
sd_pipe = self.pipeline_class(**components)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["guess_mode"] = True
|
||||
|
||||
output = sd_pipe(**inputs)
|
||||
image_slice = output.images[0, -3:, -3:, -1]
|
||||
expected_slice = np.array(
|
||||
[0.7330834, 0.590667, 0.5667336, 0.6029023, 0.5679491, 0.5968194, 0.4032986, 0.47612396, 0.5089609]
|
||||
)
|
||||
|
||||
# make sure that it's equal
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-4
|
||||
|
||||
|
||||
class StableDiffusionXLMultiControlNetPipelineFastTests(
|
||||
PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
|
||||
|
||||
Reference in New Issue
Block a user