diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index cb158a4bc1..28a4334b19 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -16,6 +16,7 @@ from pathlib import Path from typing import Dict, List, Optional, Union import torch +import torch.nn.functional as F from huggingface_hub.utils import validate_hf_hub_args from safetensors import safe_open @@ -38,6 +39,8 @@ if is_transformers_available(): ) from ..models.attention_processor import ( + AttnProcessor, + AttnProcessor2_0, IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, ) @@ -323,4 +326,14 @@ class IPAdapterMixin: self.config.encoder_hid_dim_type = None # restore original Unet attention processors layers - self.unet.set_default_attn_processor() + attn_procs = {} + for name, value in self.unet.attn_processors.items(): + attn_processor_class = ( + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor() + ) + attn_procs[name] = ( + attn_processor_class + if isinstance(value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)) + else value.__class__() + ) + self.unet.set_attn_processor(attn_procs) diff --git a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py index 3a5ff03e56..8c95fbc703 100644 --- a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py +++ b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py @@ -32,7 +32,6 @@ from diffusers import ( StableDiffusionXLPipeline, ) from diffusers.image_processor import IPAdapterMaskProcessor -from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0 from diffusers.utils import load_image from diffusers.utils.testing_utils import ( enable_full_determinism, @@ -307,6 +306,7 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin): pipeline = StableDiffusionPipeline.from_pretrained( "runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, safety_checker=None, torch_dtype=self.dtype ) + before_processors = [attn_proc.__class__ for attn_proc in pipeline.unet.attn_processors.values()] pipeline.to(torch_device) pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin") pipeline.set_ip_adapter_scale(0.7) @@ -315,11 +315,9 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin): assert getattr(pipeline, "image_encoder") is None assert getattr(pipeline, "feature_extractor") is not None - processors = [ - isinstance(attn_proc, (AttnProcessor, AttnProcessor2_0)) - for name, attn_proc in pipeline.unet.attn_processors.items() - ] - assert processors == [True] * len(processors) + after_processors = [attn_proc.__class__ for attn_proc in pipeline.unet.attn_processors.values()] + + assert before_processors == after_processors @is_flaky def test_multi(self):