1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Fix] Multiple image conditionings in a single batch for StableDiffusionControlNetPipeline (#6334)

* [Fix] Multiple image conditionings in a single batch for `StableDiffusionControlNetPipeline`.

* Refactor `check_inputs` in `StableDiffusionControlNetPipeline` to avoid redundant codes.

* Make the behavior of MultiControlNetModel to be the same to the original ControlNetModel

* Keep the code change minimum for nested list support

* Add fast test `test_inference_nested_image_input`

* Remove redundant check for nested image condition in `check_inputs`

Remove `len(image) == len(prompt)` check out of `check_image()`

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Better `ValueError` message for incompatible nested image list size

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Fix syntax error in `check_inputs`

* Remove warning message for multi-ControlNets with multiple prompts

* Fix a typo in test_controlnet.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Add test case for multiple prompts, single image conditioning in `StableDiffusionMultiControlNetPipelineFastTests`

* Improved `ValueError` message for nested `controlnet_conditioning_scale`

* Documenting the behavior of image list as `StableDiffusionControlNetPipeline` input

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
This commit is contained in:
Celestial Phineas
2024-01-17 04:40:55 +08:00
committed by GitHub
parent 49a4b377c1
commit 1040dfd9cc
2 changed files with 46 additions and 12 deletions

View File

@@ -603,15 +603,6 @@ class StableDiffusionControlNetPipeline(
f" {negative_prompt_embeds.shape}."
)
# `prompt` needs more sophisticated handling when there are multiple
# conditionings.
if isinstance(self.controlnet, MultiControlNetModel):
if isinstance(prompt, list):
logger.warning(
f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
" prompts. The conditionings will be fixed across the prompts."
)
# Check `image`
is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
self.controlnet, torch._dynamo.eval_frame.OptimizedModule
@@ -633,7 +624,13 @@ class StableDiffusionControlNetPipeline(
# When `image` is a nested list:
# (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
elif any(isinstance(i, list) for i in image):
raise ValueError("A single batch of multiple conditionings is not supported at the moment.")
transposed_image = [list(t) for t in zip(*image)]
if len(transposed_image) != len(self.controlnet.nets):
raise ValueError(
f"For multiple controlnets: if you pass`image` as a list of list, each sublist must have the same length as the number of controlnets, but the sublists in `image` got {len(transposed_image)} images and {len(self.controlnet.nets)} ControlNets."
)
for image_ in transposed_image:
self.check_image(image_, prompt, prompt_embeds)
elif len(image) != len(self.controlnet.nets):
raise ValueError(
f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
@@ -659,7 +656,10 @@ class StableDiffusionControlNetPipeline(
):
if isinstance(controlnet_conditioning_scale, list):
if any(isinstance(i, list) for i in controlnet_conditioning_scale):
raise ValueError("A single batch of multiple conditionings is not supported at the moment.")
raise ValueError(
"A single batch of varying conditioning scale settings (e.g. [[1.0, 0.5], [0.2, 0.8]]) is not supported at the moment. "
"The conditioning scale must be fixed across the batch."
)
elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
self.controlnet.nets
):
@@ -906,7 +906,9 @@ class StableDiffusionControlNetPipeline(
accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in
`init`, images must be passed as a list such that each element of the list can be correctly batched for
input to a single ControlNet.
input to a single ControlNet. When `prompt` is a list, and if a list of images is passed for a single ControlNet,
each will be paired with each prompt in the `prompt` list. This also applies to multiple ControlNets,
where a list of image lists can be passed to batch for each prompt and each ControlNet.
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
@@ -1105,6 +1107,11 @@ class StableDiffusionControlNetPipeline(
elif isinstance(controlnet, MultiControlNetModel):
images = []
# Nested lists as ControlNet condition
if isinstance(image[0], list):
# Transpose the nested image list
image = [list(t) for t in zip(*image)]
for image_ in image:
image_ = self.prepare_image(
image=image_,

View File

@@ -460,6 +460,33 @@ class StableDiffusionMultiControlNetPipelineFastTests(
except NotImplementedError:
pass
def test_inference_multiple_prompt_input(self):
device = "cpu"
components = self.get_dummy_components()
sd_pipe = StableDiffusionControlNetPipeline(**components)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["prompt"] = [inputs["prompt"], inputs["prompt"]]
inputs["image"] = [inputs["image"], inputs["image"]]
output = sd_pipe(**inputs)
image = output.images
assert image.shape == (2, 64, 64, 3)
image_1, image_2 = image
# make sure that the outputs are different
assert np.sum(np.abs(image_1 - image_2)) > 1e-3
# multiple prompts, single image conditioning
inputs = self.get_dummy_inputs(device)
inputs["prompt"] = [inputs["prompt"], inputs["prompt"]]
output_1 = sd_pipe(**inputs)
assert np.abs(image - output_1.images).max() < 1e-3
class StableDiffusionMultiControlNetOneModelPipelineFastTests(
PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase