diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index 6bdc281ef8..6cd1658c59 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -603,15 +603,6 @@ class StableDiffusionControlNetPipeline( f" {negative_prompt_embeds.shape}." ) - # `prompt` needs more sophisticated handling when there are multiple - # conditionings. - if isinstance(self.controlnet, MultiControlNetModel): - if isinstance(prompt, list): - logger.warning( - f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" - " prompts. The conditionings will be fixed across the prompts." - ) - # Check `image` is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( self.controlnet, torch._dynamo.eval_frame.OptimizedModule @@ -633,7 +624,13 @@ class StableDiffusionControlNetPipeline( # When `image` is a nested list: # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) elif any(isinstance(i, list) for i in image): - raise ValueError("A single batch of multiple conditionings is not supported at the moment.") + transposed_image = [list(t) for t in zip(*image)] + if len(transposed_image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: if you pass`image` as a list of list, each sublist must have the same length as the number of controlnets, but the sublists in `image` got {len(transposed_image)} images and {len(self.controlnet.nets)} ControlNets." + ) + for image_ in transposed_image: + self.check_image(image_, prompt, prompt_embeds) elif len(image) != len(self.controlnet.nets): raise ValueError( f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." @@ -659,7 +656,10 @@ class StableDiffusionControlNetPipeline( ): if isinstance(controlnet_conditioning_scale, list): if any(isinstance(i, list) for i in controlnet_conditioning_scale): - raise ValueError("A single batch of multiple conditionings is not supported at the moment.") + raise ValueError( + "A single batch of varying conditioning scale settings (e.g. [[1.0, 0.5], [0.2, 0.8]]) is not supported at the moment. " + "The conditioning scale must be fixed across the batch." + ) elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( self.controlnet.nets ): @@ -906,7 +906,9 @@ class StableDiffusionControlNetPipeline( accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, images must be passed as a list such that each element of the list can be correctly batched for - input to a single ControlNet. + input to a single ControlNet. When `prompt` is a list, and if a list of images is passed for a single ControlNet, + each will be paired with each prompt in the `prompt` list. This also applies to multiple ControlNets, + where a list of image lists can be passed to batch for each prompt and each ControlNet. height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): The height in pixels of the generated image. width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): @@ -1105,6 +1107,11 @@ class StableDiffusionControlNetPipeline( elif isinstance(controlnet, MultiControlNetModel): images = [] + # Nested lists as ControlNet condition + if isinstance(image[0], list): + # Transpose the nested image list + image = [list(t) for t in zip(*image)] + for image_ in image: image_ = self.prepare_image( image=image_, diff --git a/tests/pipelines/controlnet/test_controlnet.py b/tests/pipelines/controlnet/test_controlnet.py index ce86933430..c034a9b68b 100644 --- a/tests/pipelines/controlnet/test_controlnet.py +++ b/tests/pipelines/controlnet/test_controlnet.py @@ -460,6 +460,33 @@ class StableDiffusionMultiControlNetPipelineFastTests( except NotImplementedError: pass + def test_inference_multiple_prompt_input(self): + device = "cpu" + + components = self.get_dummy_components() + sd_pipe = StableDiffusionControlNetPipeline(**components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["prompt"] = [inputs["prompt"], inputs["prompt"]] + inputs["image"] = [inputs["image"], inputs["image"]] + output = sd_pipe(**inputs) + image = output.images + + assert image.shape == (2, 64, 64, 3) + + image_1, image_2 = image + # make sure that the outputs are different + assert np.sum(np.abs(image_1 - image_2)) > 1e-3 + + # multiple prompts, single image conditioning + inputs = self.get_dummy_inputs(device) + inputs["prompt"] = [inputs["prompt"], inputs["prompt"]] + output_1 = sd_pipe(**inputs) + + assert np.abs(image - output_1.images).max() < 1e-3 + class StableDiffusionMultiControlNetOneModelPipelineFastTests( PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase