1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
Files
diffusers/tests/models/testing_utils/ip_adapter.py
2026-01-13 10:38:16 +05:30

159 lines
5.9 KiB
Python

# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import pytest
import torch
from ...testing_utils import backend_empty_cache, is_ip_adapter, torch_device
def check_if_ip_adapter_correctly_set(model, processor_cls) -> bool:
    """
    Check whether the model has at least one IP Adapter attention processor.

    Args:
        model: The model to check. It must expose an `attn_processors` mapping whose values are the
            attention processor instances.
        processor_cls: The IP Adapter processor class (or tuple of classes, as accepted by `isinstance`)
            to look for.

    Returns:
        bool: True if at least one attention processor is an instance of `processor_cls`, False otherwise
        (including when the model has no attention processors at all).
    """
    return any(isinstance(processor, processor_cls) for processor in model.attn_processors.values())
@is_ip_adapter
class IPAdapterTesterMixin:
    """
    Mixin class for testing IP Adapter functionality on models.

    Expected from config mixin:
        - model_class: The model class to test

    Required properties (must be implemented by subclasses):
        - ip_adapter_processor_cls: The IP Adapter processor class to use

    Required methods (must be implemented by subclasses):
        - create_ip_adapter_state_dict(): Creates IP Adapter state dict for testing
        - modify_inputs_for_ip_adapter(): Modifies inputs to include IP Adapter data

    Expected methods from config mixin:
        - get_init_dict(): Returns dict of arguments to initialize the model
        - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass

    Pytest mark: ip_adapter
        Use `pytest -m "not ip_adapter"` to skip these tests
    """

    def setup_method(self):
        """Run garbage collection and `backend_empty_cache` for `torch_device` before each test."""
        gc.collect()
        backend_empty_cache(torch_device)

    def teardown_method(self):
        """Run garbage collection and `backend_empty_cache` for `torch_device` after each test."""
        gc.collect()
        backend_empty_cache(torch_device)

    @property
    def ip_adapter_processor_cls(self):
        """IP Adapter processor class to use for testing. Must be implemented by subclasses."""
        raise NotImplementedError("Subclasses must implement the `ip_adapter_processor_cls` property.")

    def create_ip_adapter_state_dict(self, model):
        """
        Build the IP Adapter state dict for `model`. Must be implemented by subclasses.

        The tests wrap the returned value in a list and pass it to `model._load_ip_adapter_weights`.
        """
        raise NotImplementedError("child class must implement method to create IPAdapter State Dict")

    def modify_inputs_for_ip_adapter(self, model, inputs_dict):
        """
        Return the forward-pass inputs to use once the IP Adapter is loaded. Must be implemented by subclasses.

        The tests call this with a shallow copy of the dict returned by `get_dummy_inputs()` and unpack the
        result as keyword arguments of the model's forward pass.
        """
        raise NotImplementedError("child class must implement method to create IPAdapter model inputs")

    @torch.no_grad()
    def test_load_ip_adapter(self):
        """Loading IP Adapter weights must install the adapter processors and change the model output."""
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        model = self.model_class(**init_dict).to(torch_device)

        # Baseline output before any IP Adapter weights are loaded.
        torch.manual_seed(0)
        output_no_adapter = model(**inputs_dict, return_dict=False)[0]

        ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
        model._load_ip_adapter_weights([ip_adapter_state_dict])
        assert check_if_ip_adapter_correctly_set(model, self.ip_adapter_processor_cls), (
            "IP Adapter processors not set correctly"
        )

        inputs_dict_with_adapter = self.modify_inputs_for_ip_adapter(model, inputs_dict.copy())

        # Re-seed right before the second forward pass so both passes start from the same RNG state.
        # Otherwise any stochastic op inside the model could make the outputs differ on its own and
        # the assertion below would pass without the adapter having any effect.
        torch.manual_seed(0)
        outputs_with_adapter = model(**inputs_dict_with_adapter, return_dict=False)[0]

        assert not torch.allclose(output_no_adapter, outputs_with_adapter, atol=1e-4, rtol=1e-4), (
            "Output should differ with IP Adapter enabled"
        )

    @pytest.mark.skip(
        reason="Setting IP Adapter scale is not defined at the model level. Enable this test after refactoring"
    )
    @torch.no_grad()
    def test_ip_adapter_scale(self):
        """Changing the IP Adapter scale between 0.0 and 1.0 must change the model output."""
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        model = self.model_class(**init_dict).to(torch_device)

        ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
        model._load_ip_adapter_weights([ip_adapter_state_dict])

        inputs_dict_with_adapter = self.modify_inputs_for_ip_adapter(model, inputs_dict.copy())

        # Test scale = 0.0 (no effect)
        model.set_ip_adapter_scale(0.0)
        torch.manual_seed(0)
        output_scale_zero = model(**inputs_dict_with_adapter, return_dict=False)[0]

        # Test scale = 1.0 (full effect)
        model.set_ip_adapter_scale(1.0)
        torch.manual_seed(0)
        output_scale_one = model(**inputs_dict_with_adapter, return_dict=False)[0]

        # Outputs should differ with different scales
        assert not torch.allclose(output_scale_zero, output_scale_one, atol=1e-4, rtol=1e-4), (
            "Output should differ with different IP Adapter scales"
        )

    @pytest.mark.skip(
        reason="Unloading IP Adapter is not defined at the model level. Enable this test after refactoring"
    )
    def test_unload_ip_adapter(self):
        """Unloading the IP Adapter must remove its processors and restore the original processor classes."""
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict).to(torch_device)

        # Save original processors (compared by class name, keyed by processor name)
        original_processors = {k: type(v).__name__ for k, v in model.attn_processors.items()}

        # Create and load IP adapter
        ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
        model._load_ip_adapter_weights([ip_adapter_state_dict])
        assert check_if_ip_adapter_correctly_set(model, self.ip_adapter_processor_cls), "IP Adapter should be set"

        # Unload IP adapter
        model.unload_ip_adapter()
        assert not check_if_ip_adapter_correctly_set(model, self.ip_adapter_processor_cls), (
            "IP Adapter should be unloaded"
        )

        # Verify processors are restored
        current_processors = {k: type(v).__name__ for k, v in model.attn_processors.items()}
        assert original_processors == current_processors, "Processors should be restored after unload"