mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Restore AttnProcessor2_0 in unload_ip_adapter (#7727)
* Restore AttnProcessor2_0 in unload_ip_adapter * Fix style * Update test --------- Co-authored-by: YiYi Xu <yixu310@gmail.com>
This commit is contained in:
@@ -16,6 +16,7 @@ from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
from safetensors import safe_open
|
||||
|
||||
@@ -38,6 +39,8 @@ if is_transformers_available():
|
||||
)
|
||||
|
||||
from ..models.attention_processor import (
|
||||
AttnProcessor,
|
||||
AttnProcessor2_0,
|
||||
IPAdapterAttnProcessor,
|
||||
IPAdapterAttnProcessor2_0,
|
||||
)
|
||||
@@ -323,4 +326,14 @@ class IPAdapterMixin:
|
||||
self.config.encoder_hid_dim_type = None
|
||||
|
||||
# restore original Unet attention processors layers
|
||||
self.unet.set_default_attn_processor()
|
||||
attn_procs = {}
|
||||
for name, value in self.unet.attn_processors.items():
|
||||
attn_processor_class = (
|
||||
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor()
|
||||
)
|
||||
attn_procs[name] = (
|
||||
attn_processor_class
|
||||
if isinstance(value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0))
|
||||
else value.__class__()
|
||||
)
|
||||
self.unet.set_attn_processor(attn_procs)
|
||||
|
||||
@@ -32,7 +32,6 @@ from diffusers import (
|
||||
StableDiffusionXLPipeline,
|
||||
)
|
||||
from diffusers.image_processor import IPAdapterMaskProcessor
|
||||
from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
@@ -307,6 +306,7 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, safety_checker=None, torch_dtype=self.dtype
|
||||
)
|
||||
before_processors = [attn_proc.__class__ for attn_proc in pipeline.unet.attn_processors.values()]
|
||||
pipeline.to(torch_device)
|
||||
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
|
||||
pipeline.set_ip_adapter_scale(0.7)
|
||||
@@ -315,11 +315,9 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
|
||||
assert getattr(pipeline, "image_encoder") is None
|
||||
assert getattr(pipeline, "feature_extractor") is not None
|
||||
processors = [
|
||||
isinstance(attn_proc, (AttnProcessor, AttnProcessor2_0))
|
||||
for name, attn_proc in pipeline.unet.attn_processors.items()
|
||||
]
|
||||
assert processors == [True] * len(processors)
|
||||
after_processors = [attn_proc.__class__ for attn_proc in pipeline.unet.attn_processors.values()]
|
||||
|
||||
assert before_processors == after_processors
|
||||
|
||||
@is_flaky
|
||||
def test_multi(self):
|
||||
|
||||
Reference in New Issue
Block a user