From 86714b72d07fe802129a8d892f135c4b564231fe Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Tue, 2 Jan 2024 14:40:46 +0100 Subject: [PATCH] Add unload_ip_adapter method (#6192) * Add unload_ip_adapter method * Update attn_processors with original layers * Add test * Use set_default_attn_processor --------- Co-authored-by: Sayak Paul --- src/diffusers/loaders/ip_adapter.py | 33 ++++++++++++++++++- .../test_ip_adapter_stable_diffusion.py | 20 +++++++++++ 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index 3df0492380..039b6b910a 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -132,7 +132,7 @@ class IPAdapterMixin: if keys != ["image_proj", "ip_adapter"]: raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.") - # load CLIP image encoer here if it has not been registered to the pipeline yet + # load CLIP image encoder here if it has not been registered to the pipeline yet if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None: if not isinstance(pretrained_model_name_or_path_or_dict, dict): logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}") @@ -141,12 +141,14 @@ class IPAdapterMixin: subfolder=os.path.join(subfolder, "image_encoder"), ).to(self.device, dtype=self.dtype) self.image_encoder = image_encoder + self.register_to_config(image_encoder=["transformers", "CLIPVisionModelWithProjection"]) else: raise ValueError("`image_encoder` cannot be None when using IP Adapters.") # create feature extractor if it has not been registered to the pipeline yet if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None: self.feature_extractor = CLIPImageProcessor() + self.register_to_config(feature_extractor=["transformers", "CLIPImageProcessor"]) # load ip-adapter into unet unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet @@ -157,3 +159,32 @@ class IPAdapterMixin: for attn_processor in unet.attn_processors.values(): if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)): attn_processor.scale = scale + + def unload_ip_adapter(self): + """ + Unloads the IP Adapter weights + + Examples: + + ```python + >>> # Assuming `pipeline` is already loaded with the IP Adapter weights. + >>> pipeline.unload_ip_adapter() + >>> ... + ``` + """ + # remove CLIP image encoder + if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None: + self.image_encoder = None + self.register_to_config(image_encoder=[None, None]) + + # remove feature extractor + if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None: + self.feature_extractor = None + self.register_to_config(feature_extractor=[None, None]) + + # remove hidden encoder + self.unet.encoder_hid_proj = None + self.config.encoder_hid_dim_type = None + + # restore original Unet attention processors layers + self.unet.set_default_attn_processor() diff --git a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py index dfc39d61bb..289d2b7d65 100644 --- a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py +++ b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py @@ -31,6 +31,7 @@ from diffusers import ( StableDiffusionXLInpaintPipeline, StableDiffusionXLPipeline, ) +from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0 from diffusers.utils import load_image from diffusers.utils.testing_utils import ( enable_full_determinism, @@ -228,6 +229,25 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin): assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) + def test_unload(self): + image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") + pipeline = StableDiffusionPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, safety_checker=None, torch_dtype=self.dtype + ) + pipeline.to(torch_device) + pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin") + pipeline.set_ip_adapter_scale(0.7) + + pipeline.unload_ip_adapter() + + assert getattr(pipeline, "image_encoder") is None + assert getattr(pipeline, "feature_extractor") is None + processors = [ + isinstance(attn_proc, (AttnProcessor, AttnProcessor2_0)) + for name, attn_proc in pipeline.unet.attn_processors.items() + ] + assert processors == [True] * len(processors) + @slow @require_torch_gpu