From ca61287daa24bcd7b3a8ed9159cad44596df5e49 Mon Sep 17 00:00:00 2001
From: Stephen
Date: Sat, 30 Mar 2024 12:15:29 -0400
Subject: [PATCH] Fix IP Adapter Support for SAG Pipeline (#7260)

* fix ip adapter support

* Update sag pipelines tests, adjust sag pipeline to pass tests

---------

Co-authored-by: YiYi Xu
---
 .../pipeline_stable_diffusion_sag.py          | 73 +++++++++++++++++--
 .../test_stable_diffusion_sag.py              |  6 +-
 2 files changed, 70 insertions(+), 9 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py b/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py
index 82d7474ac4..7dfefd94da 100644
--- a/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py
+++ b/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py
@@ -401,6 +401,40 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, StableDiffusionMixin, Textua
 
         return image_embeds, uncond_image_embeds
 
+    def prepare_ip_adapter_image_embeds(
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+    ):
+        if ip_adapter_image_embeds is None:
+            if not isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]
+
+            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+                raise ValueError(
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+                )
+
+            image_embeds = []
+            for single_ip_adapter_image, image_proj_layer in zip(
+                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+            ):
+                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+                single_image_embeds, single_negative_image_embeds = self.encode_image(
+                    single_ip_adapter_image, device, 1, output_hidden_state
+                )
+                single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+                single_negative_image_embeds = torch.stack(
+                    [single_negative_image_embeds] * num_images_per_prompt, dim=0
+                )
+
+                if do_classifier_free_guidance:
+                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                    single_image_embeds = single_image_embeds.to(device)
+
+                image_embeds.append(single_image_embeds)
+        else:
+            image_embeds = ip_adapter_image_embeds
+        return image_embeds
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
     def run_safety_checker(self, image, device, dtype):
         if self.safety_checker is None:
@@ -535,6 +569,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, StableDiffusionMixin, Textua
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -583,6 +618,9 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, StableDiffusionMixin, Textua
                 not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
             ip_adapter_image: (`PipelineImageInput`, *optional*):
                 Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between `PIL.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
@@ -636,13 +674,24 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, StableDiffusionMixin, Textua
         # `sag_scale = 0` means no self-attention guidance
         do_self_attention_guidance = sag_scale > 0.0
 
-        if ip_adapter_image is not None:
-            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
-            image_embeds, negative_image_embeds = self.encode_image(
-                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+            ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+                do_classifier_free_guidance,
             )
+
             if do_classifier_free_guidance:
-                image_embeds = torch.cat([negative_image_embeds, image_embeds])
+                image_embeds = []
+                negative_image_embeds = []
+                for tmp_image_embeds in ip_adapter_image_embeds:
+                    single_negative_image_embeds, single_image_embeds = tmp_image_embeds.chunk(2)
+                    image_embeds.append(single_image_embeds)
+                    negative_image_embeds.append(single_negative_image_embeds)
+            else:
+                image_embeds = ip_adapter_image_embeds
 
         # 3. Encode input prompt
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
@@ -687,8 +736,18 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, StableDiffusionMixin, Textua
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
         # 6.1 Add image embeds for IP-Adapter
-        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
-        added_uncond_kwargs = {"image_embeds": negative_image_embeds} if ip_adapter_image is not None else None
+        added_cond_kwargs = (
+            {"image_embeds": image_embeds}
+            if ip_adapter_image is not None or ip_adapter_image_embeds is not None
+            else None
+        )
+
+        if do_classifier_free_guidance:
+            added_uncond_kwargs = (
+                {"image_embeds": negative_image_embeds}
+                if ip_adapter_image is not None or ip_adapter_image_embeds is not None
+                else None
+            )
 
         # 7. Denoising loop
         store_processor = CrossAttnStoreProcessor()
diff --git a/tests/pipelines/stable_diffusion_sag/test_stable_diffusion_sag.py b/tests/pipelines/stable_diffusion_sag/test_stable_diffusion_sag.py
index 8123df3da9..f210f3e75a 100644
--- a/tests/pipelines/stable_diffusion_sag/test_stable_diffusion_sag.py
+++ b/tests/pipelines/stable_diffusion_sag/test_stable_diffusion_sag.py
@@ -32,13 +32,15 @@ from diffusers import (
 from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, torch_device
 
 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
+from ..test_pipelines_common import IPAdapterTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
 
 
 enable_full_determinism()
 
 
-class StableDiffusionSAGPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
+class StableDiffusionSAGPipelineFastTests(
+    IPAdapterTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase
+):
     pipeline_class = StableDiffusionSAGPipeline
     params = TEXT_TO_IMAGE_PARAMS
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
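
As context for the change, a minimal usage sketch of the new `ip_adapter_image_embeds` argument (not part of the patch; the checkpoint name, IP-Adapter weight name, and image path are assumptions standing in for any SD 1.5-compatible setup):

import torch
from diffusers import StableDiffusionSAGPipeline
from diffusers.utils import load_image

# Assumed model/adapter identifiers; any SD 1.5 checkpoint with a matching IP-Adapter works the same way.
pipe = StableDiffusionSAGPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

reference = load_image("ip_adapter_reference.png")  # placeholder path to any reference image

# Option 1: pass the raw image; the pipeline builds the embeddings via prepare_ip_adapter_image_embeds internally.
image = pipe("a photo of a dog", ip_adapter_image=reference, sag_scale=0.75).images[0]

# Option 2: precompute the embeddings once (negative and positive halves concatenated, since
# classifier-free guidance is on by default) and reuse them across calls.
embeds = pipe.prepare_ip_adapter_image_embeds(
    ip_adapter_image=reference,
    ip_adapter_image_embeds=None,
    device="cuda",
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
)
image = pipe("a photo of a dog", ip_adapter_image_embeds=embeds, sag_scale=0.75).images[0]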