From 1040dfd9ccd1adcc5ef98c86e7c9f98812840c46 Mon Sep 17 00:00:00 2001 From: Celestial Phineas <17267055+celestialphineas@users.noreply.github.com> Date: Wed, 17 Jan 2024 04:40:55 +0800 Subject: [PATCH] [Fix] Multiple image conditionings in a single batch for `StableDiffusionControlNetPipeline` (#6334) * [Fix] Multiple image conditionings in a single batch for `StableDiffusionControlNetPipeline`. * Refactor `check_inputs` in `StableDiffusionControlNetPipeline` to avoid redundant code. * Make the behavior of MultiControlNetModel the same as the original ControlNetModel * Keep the code change minimal for nested list support * Add fast test `test_inference_nested_image_input` * Remove redundant check for nested image condition in `check_inputs` Remove the `len(image) == len(prompt)` check from `check_image()` Co-authored-by: YiYi Xu * Better `ValueError` message for incompatible nested image list size Co-authored-by: YiYi Xu * Fix syntax error in `check_inputs` * Remove warning message for multi-ControlNets with multiple prompts * Fix a typo in test_controlnet.py Co-authored-by: YiYi Xu * Add test case for multiple prompts, single image conditioning in `StableDiffusionMultiControlNetPipelineFastTests` * Improved `ValueError` message for nested `controlnet_conditioning_scale` * Documenting the behavior of image list as `StableDiffusionControlNetPipeline` input --------- Co-authored-by: YiYi Xu --- .../controlnet/pipeline_controlnet.py | 31 ++++++++++++------- tests/pipelines/controlnet/test_controlnet.py | 27 ++++++++++++++++ 2 files changed, 46 insertions(+), 12 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index 6bdc281ef8..6cd1658c59 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -603,15 +603,6 @@ class StableDiffusionControlNetPipeline( f" {negative_prompt_embeds.shape}." 
) - # `prompt` needs more sophisticated handling when there are multiple - # conditionings. - if isinstance(self.controlnet, MultiControlNetModel): - if isinstance(prompt, list): - logger.warning( - f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" - " prompts. The conditionings will be fixed across the prompts." - ) - # Check `image` is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( self.controlnet, torch._dynamo.eval_frame.OptimizedModule @@ -633,7 +624,13 @@ class StableDiffusionControlNetPipeline( # When `image` is a nested list: # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) elif any(isinstance(i, list) for i in image): - raise ValueError("A single batch of multiple conditionings is not supported at the moment.") + transposed_image = [list(t) for t in zip(*image)] + if len(transposed_image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: if you pass`image` as a list of list, each sublist must have the same length as the number of controlnets, but the sublists in `image` got {len(transposed_image)} images and {len(self.controlnet.nets)} ControlNets." + ) + for image_ in transposed_image: + self.check_image(image_, prompt, prompt_embeds) elif len(image) != len(self.controlnet.nets): raise ValueError( f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." @@ -659,7 +656,10 @@ class StableDiffusionControlNetPipeline( ): if isinstance(controlnet_conditioning_scale, list): if any(isinstance(i, list) for i in controlnet_conditioning_scale): - raise ValueError("A single batch of multiple conditionings is not supported at the moment.") + raise ValueError( + "A single batch of varying conditioning scale settings (e.g. [[1.0, 0.5], [0.2, 0.8]]) is not supported at the moment. " + "The conditioning scale must be fixed across the batch." 
+ ) elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( self.controlnet.nets ): @@ -906,7 +906,9 @@ class StableDiffusionControlNetPipeline( accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, images must be passed as a list such that each element of the list can be correctly batched for - input to a single ControlNet. + input to a single ControlNet. When `prompt` is a list, and if a list of images is passed for a single ControlNet, + each will be paired with each prompt in the `prompt` list. This also applies to multiple ControlNets, + where a list of image lists can be passed to batch for each prompt and each ControlNet. height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The height in pixels of the generated image. width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): @@ -1105,6 +1107,11 @@ class StableDiffusionControlNetPipeline( elif isinstance(controlnet, MultiControlNetModel): images = [] + # Nested lists as ControlNet condition + if isinstance(image[0], list): + # Transpose the nested image list + image = [list(t) for t in zip(*image)] + for image_ in image: image_ = self.prepare_image( image=image_, diff --git a/tests/pipelines/controlnet/test_controlnet.py b/tests/pipelines/controlnet/test_controlnet.py index ce86933430..c034a9b68b 100644 --- a/tests/pipelines/controlnet/test_controlnet.py +++ b/tests/pipelines/controlnet/test_controlnet.py @@ -460,6 +460,33 @@ class StableDiffusionMultiControlNetPipelineFastTests( except NotImplementedError: pass + def test_inference_multiple_prompt_input(self): + device = "cpu" + + components = self.get_dummy_components() + sd_pipe = StableDiffusionControlNetPipeline(**components) + sd_pipe = sd_pipe.to(torch_device) + 
sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["prompt"] = [inputs["prompt"], inputs["prompt"]] + inputs["image"] = [inputs["image"], inputs["image"]] + output = sd_pipe(**inputs) + image = output.images + + assert image.shape == (2, 64, 64, 3) + + image_1, image_2 = image + # make sure that the outputs are different + assert np.sum(np.abs(image_1 - image_2)) > 1e-3 + + # multiple prompts, single image conditioning + inputs = self.get_dummy_inputs(device) + inputs["prompt"] = [inputs["prompt"], inputs["prompt"]] + output_1 = sd_pipe(**inputs) + + assert np.abs(image - output_1.images).max() < 1e-3 + class StableDiffusionMultiControlNetOneModelPipelineFastTests( PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase