mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Fix] Multiple image conditionings in a single batch for StableDiffusionControlNetPipeline (#6334)
* [Fix] Multiple image conditionings in a single batch for `StableDiffusionControlNetPipeline`. * Refactor `check_inputs` in `StableDiffusionControlNetPipeline` to avoid redundant codes. * Make the behavior of MultiControlNetModel to be the same to the original ControlNetModel * Keep the code change minimum for nested list support * Add fast test `test_inference_nested_image_input` * Remove redundant check for nested image condition in `check_inputs` Remove `len(image) == len(prompt)` check out of `check_image()` Co-authored-by: YiYi Xu <yixu310@gmail.com> * Better `ValueError` message for incompatible nested image list size Co-authored-by: YiYi Xu <yixu310@gmail.com> * Fix syntax error in `check_inputs` * Remove warning message for multi-ControlNets with multiple prompts * Fix a typo in test_controlnet.py Co-authored-by: YiYi Xu <yixu310@gmail.com> * Add test case for multiple prompts, single image conditioning in `StableDiffusionMultiControlNetPipelineFastTests` * Improved `ValueError` message for nested `controlnet_conditioning_scale` * Documenting the behavior of image list as `StableDiffusionControlNetPipeline` input --------- Co-authored-by: YiYi Xu <yixu310@gmail.com>
This commit is contained in:
committed by
GitHub
parent
49a4b377c1
commit
1040dfd9cc
@@ -603,15 +603,6 @@ class StableDiffusionControlNetPipeline(
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
# `prompt` needs more sophisticated handling when there are multiple
|
||||
# conditionings.
|
||||
if isinstance(self.controlnet, MultiControlNetModel):
|
||||
if isinstance(prompt, list):
|
||||
logger.warning(
|
||||
f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
|
||||
" prompts. The conditionings will be fixed across the prompts."
|
||||
)
|
||||
|
||||
# Check `image`
|
||||
is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
|
||||
self.controlnet, torch._dynamo.eval_frame.OptimizedModule
|
||||
@@ -633,7 +624,13 @@ class StableDiffusionControlNetPipeline(
|
||||
# When `image` is a nested list:
|
||||
# (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
|
||||
elif any(isinstance(i, list) for i in image):
|
||||
raise ValueError("A single batch of multiple conditionings is not supported at the moment.")
|
||||
transposed_image = [list(t) for t in zip(*image)]
|
||||
if len(transposed_image) != len(self.controlnet.nets):
|
||||
raise ValueError(
|
||||
f"For multiple controlnets: if you pass`image` as a list of list, each sublist must have the same length as the number of controlnets, but the sublists in `image` got {len(transposed_image)} images and {len(self.controlnet.nets)} ControlNets."
|
||||
)
|
||||
for image_ in transposed_image:
|
||||
self.check_image(image_, prompt, prompt_embeds)
|
||||
elif len(image) != len(self.controlnet.nets):
|
||||
raise ValueError(
|
||||
f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
|
||||
@@ -659,7 +656,10 @@ class StableDiffusionControlNetPipeline(
|
||||
):
|
||||
if isinstance(controlnet_conditioning_scale, list):
|
||||
if any(isinstance(i, list) for i in controlnet_conditioning_scale):
|
||||
raise ValueError("A single batch of multiple conditionings is not supported at the moment.")
|
||||
raise ValueError(
|
||||
"A single batch of varying conditioning scale settings (e.g. [[1.0, 0.5], [0.2, 0.8]]) is not supported at the moment. "
|
||||
"The conditioning scale must be fixed across the batch."
|
||||
)
|
||||
elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
|
||||
self.controlnet.nets
|
||||
):
|
||||
@@ -906,7 +906,9 @@ class StableDiffusionControlNetPipeline(
|
||||
accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
|
||||
and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in
|
||||
`init`, images must be passed as a list such that each element of the list can be correctly batched for
|
||||
input to a single ControlNet.
|
||||
input to a single ControlNet. When `prompt` is a list, and if a list of images is passed for a single ControlNet,
|
||||
each will be paired with each prompt in the `prompt` list. This also applies to multiple ControlNets,
|
||||
where a list of image lists can be passed to batch for each prompt and each ControlNet.
|
||||
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
||||
@@ -1105,6 +1107,11 @@ class StableDiffusionControlNetPipeline(
|
||||
elif isinstance(controlnet, MultiControlNetModel):
|
||||
images = []
|
||||
|
||||
# Nested lists as ControlNet condition
|
||||
if isinstance(image[0], list):
|
||||
# Transpose the nested image list
|
||||
image = [list(t) for t in zip(*image)]
|
||||
|
||||
for image_ in image:
|
||||
image_ = self.prepare_image(
|
||||
image=image_,
|
||||
|
||||
@@ -460,6 +460,33 @@ class StableDiffusionMultiControlNetPipelineFastTests(
|
||||
except NotImplementedError:
|
||||
pass
|
||||
|
||||
def test_inference_multiple_prompt_input(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = StableDiffusionControlNetPipeline(**components)
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["prompt"] = [inputs["prompt"], inputs["prompt"]]
|
||||
inputs["image"] = [inputs["image"], inputs["image"]]
|
||||
output = sd_pipe(**inputs)
|
||||
image = output.images
|
||||
|
||||
assert image.shape == (2, 64, 64, 3)
|
||||
|
||||
image_1, image_2 = image
|
||||
# make sure that the outputs are different
|
||||
assert np.sum(np.abs(image_1 - image_2)) > 1e-3
|
||||
|
||||
# multiple prompts, single image conditioning
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["prompt"] = [inputs["prompt"], inputs["prompt"]]
|
||||
output_1 = sd_pipe(**inputs)
|
||||
|
||||
assert np.abs(image - output_1.images).max() < 1e-3
|
||||
|
||||
|
||||
class StableDiffusionMultiControlNetOneModelPipelineFastTests(
|
||||
PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
|
||||
|
||||
Reference in New Issue
Block a user