1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Restore AttnProcessor2_0 in unload_ip_adapter (#7727)

* Restore AttnProcessor2_0 in unload_ip_adapter

* Fix style

* Update test

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
This commit is contained in:
Fabio Rigano
2024-04-23 01:59:03 +02:00
committed by GitHub
parent 21c747fa0f
commit 065f251766
2 changed files with 18 additions and 7 deletions

View File

@@ -16,6 +16,7 @@ from pathlib import Path
from typing import Dict, List, Optional, Union
import torch
import torch.nn.functional as F
from huggingface_hub.utils import validate_hf_hub_args
from safetensors import safe_open
@@ -38,6 +39,8 @@ if is_transformers_available():
)
from ..models.attention_processor import (
AttnProcessor,
AttnProcessor2_0,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
)
@@ -323,4 +326,14 @@ class IPAdapterMixin:
self.config.encoder_hid_dim_type = None
# restore original Unet attention processors layers
self.unet.set_default_attn_processor()
attn_procs = {}
for name, value in self.unet.attn_processors.items():
attn_processor_class = (
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor()
)
attn_procs[name] = (
attn_processor_class
if isinstance(value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0))
else value.__class__()
)
self.unet.set_attn_processor(attn_procs)

View File

@@ -32,7 +32,6 @@ from diffusers import (
StableDiffusionXLPipeline,
)
from diffusers.image_processor import IPAdapterMaskProcessor
from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
enable_full_determinism,
@@ -307,6 +306,7 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
pipeline = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, safety_checker=None, torch_dtype=self.dtype
)
before_processors = [attn_proc.__class__ for attn_proc in pipeline.unet.attn_processors.values()]
pipeline.to(torch_device)
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
pipeline.set_ip_adapter_scale(0.7)
@@ -315,11 +315,9 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
assert getattr(pipeline, "image_encoder") is None
assert getattr(pipeline, "feature_extractor") is not None
processors = [
isinstance(attn_proc, (AttnProcessor, AttnProcessor2_0))
for name, attn_proc in pipeline.unet.attn_processors.items()
]
assert processors == [True] * len(processors)
after_processors = [attn_proc.__class__ for attn_proc in pipeline.unet.attn_processors.values()]
assert before_processors == after_processors
@is_flaky
def test_multi(self):