diff --git a/src/diffusers/pipelines/controlnet/multicontrolnet.py b/src/diffusers/pipelines/controlnet/multicontrolnet.py
index 2df8c7edd7..2214611d26 100644
--- a/src/diffusers/pipelines/controlnet/multicontrolnet.py
+++ b/src/diffusers/pipelines/controlnet/multicontrolnet.py
@@ -39,6 +39,7 @@ class MultiControlNetModel(ModelMixin):
         class_labels: Optional[torch.Tensor] = None,
         timestep_cond: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         guess_mode: bool = False,
         return_dict: bool = True,
@@ -53,6 +54,7 @@ class MultiControlNetModel(ModelMixin):
                 class_labels=class_labels,
                 timestep_cond=timestep_cond,
                 attention_mask=attention_mask,
+                added_cond_kwargs=added_cond_kwargs,
                 cross_attention_kwargs=cross_attention_kwargs,
                 guess_mode=guess_mode,
                 return_dict=return_dict,
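For context, the two hunks above sit inside `MultiControlNetModel.forward`, which zips each sub-ControlNet with its own conditioning image and scale and sums the per-net residuals; the patch simply threads `added_cond_kwargs` (the pooled text embeddings and time ids that SDXL's "text_time" conditioning requires) through to every sub-net. A condensed paraphrase of that loop, not the verbatim implementation:

    for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
        down_samples, mid_sample = controlnet(
            sample,
            timestep,
            encoder_hidden_states,
            image,
            scale,
            added_cond_kwargs=added_cond_kwargs,  # newly forwarded; required by the SDXL blocks
            return_dict=return_dict,
        )
        # merge samples: element-wise sum of the residuals from each ControlNet
        if i == 0:
            down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
        else:
            down_block_res_samples = [
                prev + curr for prev, curr in zip(down_block_res_samples, down_samples)
            ]
            mid_block_res_sample += mid_sample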
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index d06a49421e..c4a8777c5e 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -149,7 +149,7 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
         tokenizer: CLIPTokenizer,
         tokenizer_2: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        controlnet: ControlNetModel,
+        controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
         scheduler: KarrasDiffusionSchedulers,
         force_zeros_for_empty_prompt: bool = True,
         add_watermarker: Optional[bool] = None,
@@ -157,7 +157,7 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
         super().__init__()

         if isinstance(controlnet, (list, tuple)):
-            raise ValueError("MultiControlNet is not yet supported.")
+            controlnet = MultiControlNetModel(controlnet)

         self.register_modules(
             vae=vae,
@@ -530,6 +539,15 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
                 "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
             )

+        # `prompt` needs more sophisticated handling when there are multiple
+        # conditionings.
+        if isinstance(self.controlnet, MultiControlNetModel):
+            if isinstance(prompt, list):
+                logger.warning(
+                    f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
+                    " prompts. The same conditionings will be applied to every prompt."
+                )
+
         # Check `image`
         is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
             self.controlnet, torch._dynamo.eval_frame.OptimizedModule
@@ -540,6 +549,25 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
             and isinstance(self.controlnet._orig_mod, ControlNetModel)
         ):
             self.check_image(image, prompt, prompt_embeds)
+        elif (
+            isinstance(self.controlnet, MultiControlNetModel)
+            or is_compiled
+            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+        ):
+            if not isinstance(image, list):
+                raise TypeError("For multiple controlnets: `image` must be of type `list`")
+
+            # When `image` is a nested list:
+            # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
+            elif any(isinstance(i, list) for i in image):
+                raise ValueError("Only a single batch of multiple conditionings is supported at the moment.")
+            elif len(image) != len(self.controlnet.nets):
+                raise ValueError(
+                    f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
+                )
+
+            for image_ in image:
+                self.check_image(image_, prompt, prompt_embeds)
         else:
             assert False

@@ -551,14 +579,41 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
         ):
             if not isinstance(controlnet_conditioning_scale, float):
                 raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
+        elif (
+            isinstance(self.controlnet, MultiControlNetModel)
+            or is_compiled
+            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
+        ):
+            if isinstance(controlnet_conditioning_scale, list):
+                if any(isinstance(i, list) for i in controlnet_conditioning_scale):
+                    raise ValueError("Only a single batch of multiple conditionings is supported at the moment.")
+                elif len(controlnet_conditioning_scale) != len(
+                    self.controlnet.nets
+                ):
+                    raise ValueError(
+                        "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
+                        " the same length as the number of controlnets"
+                    )
         else:
             assert False

+        if not isinstance(control_guidance_start, (tuple, list)):
+            control_guidance_start = [control_guidance_start]
+
+        if not isinstance(control_guidance_end, (tuple, list)):
+            control_guidance_end = [control_guidance_end]
+
         if len(control_guidance_start) != len(control_guidance_end):
             raise ValueError(
                 f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
             )

+        if isinstance(self.controlnet, MultiControlNetModel):
+            if len(control_guidance_start) != len(self.controlnet.nets):
+                raise ValueError(
+                    f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)} elements."
+                )
+
         for start, end in zip(control_guidance_start, control_guidance_end):
             if start >= end:
                 raise ValueError(
@@ -569,6 +624,7 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
             if end > 1.0:
                 raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")

+    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
     def check_image(self, image, prompt, prompt_embeds):
         image_is_pil = isinstance(image, PIL.Image.Image)
         image_is_tensor = isinstance(image, torch.Tensor)
@@ -606,6 +662,7 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
             f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
         )

+    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
     def prepare_image(
         self,
         image,
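Taken together, these checks pin down the multi-ControlNet calling convention: `image` must be a flat list with exactly one conditioning image per ControlNet (nested per-prompt lists are rejected for now), `controlnet_conditioning_scale` may be a single float or a flat list of matching length, and the guidance windows are validated pairwise. Illustrative calls for a two-net pipeline (variable names are hypothetical):

    # accepted: one conditioning image per ControlNet, per-net scales
    pipe(prompt, image=[canny_image, depth_image], controlnet_conditioning_scale=[0.5, 0.8])

    # accepted: a single float scale is broadcast to every net later in __call__
    pipe(prompt, image=[canny_image, depth_image], controlnet_conditioning_scale=0.5)

    # rejected: nested per-prompt conditioning lists raise a ValueError
    pipe(prompt, image=[[canny_1, depth_1], [canny_2, depth_2]])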
@@ -888,6 +945,9 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0

+        if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
+
         global_pool_conditions = (
             controlnet.config.global_pool_conditions
             if isinstance(controlnet, ControlNetModel)
@@ -933,6 +993,26 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
                 guess_mode=guess_mode,
             )
             height, width = image.shape[-2:]
+        elif isinstance(controlnet, MultiControlNetModel):
+            images = []
+
+            for image_ in image:
+                image_ = self.prepare_image(
+                    image=image_,
+                    width=width,
+                    height=height,
+                    batch_size=batch_size * num_images_per_prompt,
+                    num_images_per_prompt=num_images_per_prompt,
+                    device=device,
+                    dtype=controlnet.dtype,
+                    do_classifier_free_guidance=do_classifier_free_guidance,
+                    guess_mode=guess_mode,
+                )
+
+                images.append(image_)
+
+            image = images
+            height, width = image[0].shape[-2:]
         else:
             assert False

@@ -963,12 +1043,15 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
                 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
                 for s, e in zip(control_guidance_start, control_guidance_end)
             ]
-            controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps)
-
-        original_size = original_size or image.shape[-2:]
-        target_size = target_size or (height, width)
+            controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)

         # 7.2 Prepare added time ids & embeddings
+        if isinstance(image, list):
+            original_size = original_size or image[0].shape[-2:]
+        else:
+            original_size = original_size or image.shape[-2:]
+        target_size = target_size or (height, width)
+
         add_text_embeds = pooled_prompt_embeds
         add_time_ids = self._get_add_time_ids(
             original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
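With these hunks the SDXL pipeline accepts the same multi-ControlNet calling convention as the SD v1 ControlNet pipeline. A minimal usage sketch; the two ControlNet checkpoint ids are placeholders (substitute real SDXL-compatible repos), and `canny_image`/`depth_image` are assumed to be pre-computed conditioning images:

    import torch
    from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline

    # hypothetical checkpoint ids, for illustration only
    controlnet_canny = ControlNetModel.from_pretrained("org/controlnet-canny-sdxl", torch_dtype=torch.float16)
    controlnet_depth = ControlNetModel.from_pretrained("org/controlnet-depth-sdxl", torch_dtype=torch.float16)

    pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        controlnet=[controlnet_canny, controlnet_depth],  # the list is wrapped into MultiControlNetModel by __init__
        torch_dtype=torch.float16,
    ).to("cuda")

    image = pipe(
        "a futuristic cityscape",
        image=[canny_image, depth_image],          # one conditioning image per ControlNet
        controlnet_conditioning_scale=[0.5, 0.5],  # or a single float, broadcast to both nets
        control_guidance_start=[0.0, 0.0],
        control_guidance_end=[0.8, 0.5],           # per-net (start, end) windows over the denoising schedule
    ).images[0]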
diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py
index e9d64c89c6..0c7ee4972f 100644
--- a/tests/pipelines/controlnet/test_controlnet_sdxl.py
+++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py
@@ -26,6 +26,7 @@ from diffusers import (
     StableDiffusionXLControlNetPipeline,
     UNet2DConditionModel,
 )
+from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
 from diffusers.utils import randn_tensor, torch_device
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu
@@ -46,7 +47,7 @@ from ..test_pipelines_common import (
 enable_full_determinism()


-class ControlNetPipelineSDXLFastTests(
+class StableDiffusionXLControlNetPipelineFastTests(
     PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
 ):
     pipeline_class = StableDiffusionXLControlNetPipeline
@@ -297,3 +298,383 @@ class ControlNetPipelineSDXLFastTests(

     # make sure that it's equal
     assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
+
+
+class StableDiffusionXLMultiControlNetPipelineFastTests(
+    PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
+):
+    pipeline_class = StableDiffusionXLControlNetPipeline
+    params = TEXT_TO_IMAGE_PARAMS
+    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+    image_params = frozenset([])  # TODO: add image_params once refactored VaeImageProcessor.preprocess
+
+    def get_dummy_components(self):
+        torch.manual_seed(0)
+        unet = UNet2DConditionModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            sample_size=32,
+            in_channels=4,
+            out_channels=4,
+            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+            # SD2-specific config below
+            attention_head_dim=(2, 4),
+            use_linear_projection=True,
+            addition_embed_type="text_time",
+            addition_time_embed_dim=8,
+            transformer_layers_per_block=(1, 2),
+            projection_class_embeddings_input_dim=80,  # 6 * 8 + 32
+            cross_attention_dim=64,
+        )
+        torch.manual_seed(0)
+
+        def init_weights(m):
+            if isinstance(m, torch.nn.Conv2d):
+                torch.nn.init.normal_(m.weight)
+                m.bias.data.fill_(1.0)
+
+        controlnet1 = ControlNetModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            in_channels=4,
+            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+            conditioning_embedding_out_channels=(16, 32),
+            # SD2-specific config below
+            attention_head_dim=(2, 4),
+            use_linear_projection=True,
+            addition_embed_type="text_time",
+            addition_time_embed_dim=8,
+            transformer_layers_per_block=(1, 2),
+            projection_class_embeddings_input_dim=80,  # 6 * 8 + 32
+            cross_attention_dim=64,
+        )
+        controlnet1.controlnet_down_blocks.apply(init_weights)
+
+        torch.manual_seed(0)
+        controlnet2 = ControlNetModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            in_channels=4,
+            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+            conditioning_embedding_out_channels=(16, 32),
+            # SD2-specific config below
+            attention_head_dim=(2, 4),
+            use_linear_projection=True,
+            addition_embed_type="text_time",
+            addition_time_embed_dim=8,
+            transformer_layers_per_block=(1, 2),
+            projection_class_embeddings_input_dim=80,  # 6 * 8 + 32
+            cross_attention_dim=64,
+        )
+        controlnet2.controlnet_down_blocks.apply(init_weights)
+
+        torch.manual_seed(0)
+        scheduler = EulerDiscreteScheduler(
+            beta_start=0.00085,
+            beta_end=0.012,
+            steps_offset=1,
+            beta_schedule="scaled_linear",
+            timestep_spacing="leading",
+        )
+        torch.manual_seed(0)
+        vae = AutoencoderKL(
+            block_out_channels=[32, 64],
+            in_channels=3,
+            out_channels=3,
+            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+            latent_channels=4,
+        )
+        torch.manual_seed(0)
+        text_encoder_config = CLIPTextConfig(
+            bos_token_id=0,
+            eos_token_id=2,
+            hidden_size=32,
+            intermediate_size=37,
+            layer_norm_eps=1e-05,
+            num_attention_heads=4,
+            num_hidden_layers=5,
+            pad_token_id=1,
+            vocab_size=1000,
+            # SD2-specific config below
+            hidden_act="gelu",
+            projection_dim=32,
+        )
+        text_encoder = CLIPTextModel(text_encoder_config)
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
+        tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        controlnet = MultiControlNetModel([controlnet1, controlnet2])
+
+        components = {
+            "unet": unet,
+            "controlnet": controlnet,
+            "scheduler": scheduler,
+            "vae": vae,
+            "text_encoder": text_encoder,
+            "tokenizer": tokenizer,
+            "text_encoder_2": text_encoder_2,
+            "tokenizer_2": tokenizer_2,
+        }
+        return components
+
+    def get_dummy_inputs(self, device, seed=0):
+        if str(device).startswith("mps"):
+            generator = torch.manual_seed(seed)
+        else:
+            generator = torch.Generator(device=device).manual_seed(seed)
+
+        controlnet_embedder_scale_factor = 2
+
+        images = [
+            randn_tensor(
+                (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
+                generator=generator,
+                device=torch.device(device),
+            ),
+            randn_tensor(
+                (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
+                generator=generator,
+                device=torch.device(device),
+            ),
+        ]
+
+        inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "generator": generator,
+            "num_inference_steps": 2,
+            "guidance_scale": 6.0,
+            "output_type": "numpy",
+            "image": images,
+        }
+
+        return inputs
+
+    def test_control_guidance_switch(self):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+
+        scale = 10.0
+        steps = 4
+
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["num_inference_steps"] = steps
+        inputs["controlnet_conditioning_scale"] = scale
+        output_1 = pipe(**inputs)[0]
+
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["num_inference_steps"] = steps
+        inputs["controlnet_conditioning_scale"] = scale
+        output_2 = pipe(**inputs, control_guidance_start=0.1, control_guidance_end=0.2)[0]
+
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["num_inference_steps"] = steps
+        inputs["controlnet_conditioning_scale"] = scale
+        output_3 = pipe(**inputs, control_guidance_start=[0.1, 0.3], control_guidance_end=[0.2, 0.7])[0]
+
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["num_inference_steps"] = steps
+        inputs["controlnet_conditioning_scale"] = scale
+        output_4 = pipe(**inputs, control_guidance_start=0.4, control_guidance_end=[0.5, 0.8])[0]
+
+        # make sure that all outputs are different
+        assert np.sum(np.abs(output_1 - output_2)) > 1e-3
+        assert np.sum(np.abs(output_1 - output_3)) > 1e-3
+        assert np.sum(np.abs(output_1 - output_4)) > 1e-3
+
+    def test_attention_slicing_forward_pass(self):
+        return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)
+
+    @unittest.skipIf(
+        torch_device != "cuda" or not is_xformers_available(),
+        reason="XFormers attention is only available with CUDA and `xformers` installed",
+    )
+    def test_xformers_attention_forwardGenerator_pass(self):
+        self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3)
+
+    def test_inference_batch_single_identical(self):
+        self._test_inference_batch_single_identical(expected_max_diff=2e-3)
+
+
+class StableDiffusionXLMultiControlNetOneModelPipelineFastTests(
+    PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
+):
+    pipeline_class = StableDiffusionXLControlNetPipeline
+    params = TEXT_TO_IMAGE_PARAMS
+    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+    image_params = frozenset([])  # TODO: add image_params once refactored VaeImageProcessor.preprocess
+
+    def get_dummy_components(self):
+        torch.manual_seed(0)
+        unet = UNet2DConditionModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            sample_size=32,
+            in_channels=4,
+            out_channels=4,
+            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+            # SD2-specific config below
+            attention_head_dim=(2, 4),
+            use_linear_projection=True,
+            addition_embed_type="text_time",
+            addition_time_embed_dim=8,
+            transformer_layers_per_block=(1, 2),
+            projection_class_embeddings_input_dim=80,  # 6 * 8 + 32
+            cross_attention_dim=64,
+        )
+        torch.manual_seed(0)
+
+        def init_weights(m):
+            if isinstance(m, torch.nn.Conv2d):
+                torch.nn.init.normal_(m.weight)
+                m.bias.data.fill_(1.0)
+
+        controlnet = ControlNetModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            in_channels=4,
+            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+            conditioning_embedding_out_channels=(16, 32),
+            # SD2-specific config below
+            attention_head_dim=(2, 4),
+            use_linear_projection=True,
+            addition_embed_type="text_time",
+            addition_time_embed_dim=8,
+            transformer_layers_per_block=(1, 2),
+            projection_class_embeddings_input_dim=80,  # 6 * 8 + 32
+            cross_attention_dim=64,
+        )
+        controlnet.controlnet_down_blocks.apply(init_weights)
+
+        torch.manual_seed(0)
+        scheduler = EulerDiscreteScheduler(
+            beta_start=0.00085,
+            beta_end=0.012,
+            steps_offset=1,
+            beta_schedule="scaled_linear",
+            timestep_spacing="leading",
+        )
+        torch.manual_seed(0)
+        vae = AutoencoderKL(
+            block_out_channels=[32, 64],
+            in_channels=3,
+            out_channels=3,
+            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+            latent_channels=4,
+        )
+        torch.manual_seed(0)
+        text_encoder_config = CLIPTextConfig(
+            bos_token_id=0,
+            eos_token_id=2,
+            hidden_size=32,
+            intermediate_size=37,
+            layer_norm_eps=1e-05,
+            num_attention_heads=4,
+            num_hidden_layers=5,
+            pad_token_id=1,
+            vocab_size=1000,
+            # SD2-specific config below
+            hidden_act="gelu",
+            projection_dim=32,
+        )
+        text_encoder = CLIPTextModel(text_encoder_config)
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
+        tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        controlnet = MultiControlNetModel([controlnet])
+
+        components = {
+            "unet": unet,
+            "controlnet": controlnet,
+            "scheduler": scheduler,
+            "vae": vae,
+            "text_encoder": text_encoder,
+            "tokenizer": tokenizer,
+            "text_encoder_2": text_encoder_2,
+            "tokenizer_2": tokenizer_2,
+        }
+        return components
+
+    def get_dummy_inputs(self, device, seed=0):
+        if str(device).startswith("mps"):
+            generator = torch.manual_seed(seed)
+        else:
+            generator = torch.Generator(device=device).manual_seed(seed)
+
+        controlnet_embedder_scale_factor = 2
+        images = [
+            randn_tensor(
+                (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor),
+                generator=generator,
+                device=torch.device(device),
+            ),
+        ]
+
+        inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "generator": generator,
+            "num_inference_steps": 2,
+            "guidance_scale": 6.0,
+            "output_type": "numpy",
+            "image": images,
+        }
+
+        return inputs
+
+    def test_control_guidance_switch(self):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+
+        scale = 10.0
+        steps = 4
+
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["num_inference_steps"] = steps
+        inputs["controlnet_conditioning_scale"] = scale
+        output_1 = pipe(**inputs)[0]
+
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["num_inference_steps"] = steps
+        inputs["controlnet_conditioning_scale"] = scale
+        output_2 = pipe(**inputs, control_guidance_start=0.1, control_guidance_end=0.2)[0]
+
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["num_inference_steps"] = steps
+        inputs["controlnet_conditioning_scale"] = scale
+        output_3 = pipe(
+            **inputs,
+            control_guidance_start=[0.1],
+            control_guidance_end=[0.2],
+        )[0]
+
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["num_inference_steps"] = steps
+        inputs["controlnet_conditioning_scale"] = scale
+        output_4 = pipe(**inputs, control_guidance_start=0.4, control_guidance_end=[0.5])[0]
+
+        # make sure that all outputs are different
+        assert np.sum(np.abs(output_1 - output_2)) > 1e-3
+        assert np.sum(np.abs(output_1 - output_3)) > 1e-3
+        assert np.sum(np.abs(output_1 - output_4)) > 1e-3
+
+    def test_attention_slicing_forward_pass(self):
+        return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)
+
+    @unittest.skipIf(
+        torch_device != "cuda" or not is_xformers_available(),
+        reason="XFormers attention is only available with CUDA and `xformers` installed",
+    )
+    def test_xformers_attention_forwardGenerator_pass(self):
+        self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3)
+
+    def test_inference_batch_single_identical(self):
+        self._test_inference_batch_single_identical(expected_max_diff=2e-3)
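The `test_control_guidance_switch` tests above exercise the `controlnet_keep` schedule built in `__call__`: at each denoising step, a ControlNet's conditioning scale is multiplied by 1.0 only while the normalized step index lies inside that net's `[start, end]` window, and by 0.0 otherwise. A standalone sketch of the same arithmetic (function name illustrative):

    def controlnet_keep_schedule(num_steps, starts, ends):
        # mirrors the pipeline's list comprehension:
        # keep = 1.0 unless i / num_steps < start or (i + 1) / num_steps > end
        return [
            [1.0 - float(i / num_steps < s or (i + 1) / num_steps > e) for s, e in zip(starts, ends)]
            for i in range(num_steps)
        ]

    # e.g. 4 steps, net 1 active over [0.0, 0.5], net 2 over [0.25, 1.0]:
    # controlnet_keep_schedule(4, [0.0, 0.25], [0.5, 1.0])
    # -> [[1.0, 0.0], [1.0, 1.0], [0.0, 1.0], [0.0, 1.0]]

This is why `output_2`, `output_3`, and `output_4` must differ from `output_1`: narrowing a window zeroes out that ControlNet's contribution for part of the schedule.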