mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
159 lines
5.9 KiB
Python
159 lines
5.9 KiB
Python
# coding=utf-8
|
|
# Copyright 2025 HuggingFace Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import gc
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from ...testing_utils import backend_empty_cache, is_ip_adapter, torch_device
|
|
|
|
|
|
def check_if_ip_adapter_correctly_set(model, processor_cls) -> bool:
    """
    Check if IP Adapter processors are correctly set in the model.

    Args:
        model: The model to check. It must expose an ``attn_processors`` mapping
            whose values are the attention processor instances.
        processor_cls: The IP Adapter processor class (or tuple of classes, as
            accepted by ``isinstance``) to look for.

    Returns:
        bool: True if at least one attention processor is an instance of
        ``processor_cls``. False otherwise, including when ``attn_processors``
        is empty.
    """
    # One matching processor is enough. This does not verify that *every*
    # attention layer was switched to the IP Adapter processor.
    return any(isinstance(processor, processor_cls) for processor in model.attn_processors.values())
|
|
|
|
|
|
@is_ip_adapter
class IPAdapterTesterMixin:
    """
    Mixin class for testing IP Adapter functionality on models.

    Expected from config mixin:
        - model_class: The model class to test

    Required properties (must be implemented by subclasses):
        - ip_adapter_processor_cls: The IP Adapter processor class to use

    Required methods (must be implemented by subclasses):
        - create_ip_adapter_state_dict(): Creates IP Adapter state dict for testing
        - modify_inputs_for_ip_adapter(): Modifies inputs to include IP Adapter data

    Expected methods from config mixin:
        - get_init_dict(): Returns dict of arguments to initialize the model
        - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass

    Pytest mark: ip_adapter
    Use `pytest -m "not ip_adapter"` to skip these tests
    """

    def setup_method(self):
        """Collect Python garbage and call ``backend_empty_cache`` for ``torch_device`` before each test."""
        gc.collect()
        backend_empty_cache(torch_device)

    def teardown_method(self):
        """Collect Python garbage and call ``backend_empty_cache`` for ``torch_device`` after each test."""
        gc.collect()
        backend_empty_cache(torch_device)

    @property
    def ip_adapter_processor_cls(self):
        """IP Adapter processor class to use for testing. Must be implemented by subclasses."""
        raise NotImplementedError("Subclasses must implement the `ip_adapter_processor_cls` property.")

    def create_ip_adapter_state_dict(self, model):
        """
        Build the IP Adapter state dict for ``model``.

        Must be implemented by subclasses. The result is passed, wrapped in a
        single-element list, to ``model._load_ip_adapter_weights``.
        """
        raise NotImplementedError("child class must implement method to create IPAdapter State Dict")

    def modify_inputs_for_ip_adapter(self, model, inputs_dict):
        """
        Return forward-pass inputs that include the IP Adapter data.

        Must be implemented by subclasses. The callers pass a shallow copy of
        the dummy inputs, and the returned dict is unpacked into
        ``model(**inputs, return_dict=False)``.
        """
        raise NotImplementedError("child class must implement method to create IPAdapter model inputs")

    @torch.no_grad()
    def test_load_ip_adapter(self):
        """Loading IP Adapter weights installs the processor class and changes the model output."""
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        model = self.model_class(**init_dict).to(torch_device)

        torch.manual_seed(0)
        output_no_adapter = model(**inputs_dict, return_dict=False)[0]

        ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)

        model._load_ip_adapter_weights([ip_adapter_state_dict])
        assert check_if_ip_adapter_correctly_set(model, self.ip_adapter_processor_cls), (
            "IP Adapter processors not set correctly"
        )

        # Shallow copy so the hook can add keys without touching the baseline inputs dict.
        inputs_dict_with_adapter = self.modify_inputs_for_ip_adapter(model, inputs_dict.copy())

        # Re-seed so both forward passes start from the same RNG state. The hooks
        # above may draw random tensors, and without a re-seed any RNG-dependent
        # computation could make the outputs differ even if the adapter had no effect.
        torch.manual_seed(0)
        outputs_with_adapter = model(**inputs_dict_with_adapter, return_dict=False)[0]

        assert not torch.allclose(output_no_adapter, outputs_with_adapter, atol=1e-4, rtol=1e-4), (
            "Output should differ with IP Adapter enabled"
        )

    @pytest.mark.skip(
        reason="Setting IP Adapter scale is not defined at the model level. Enable this test after refactoring"
    )
    @torch.no_grad()
    def test_ip_adapter_scale(self):
        """Changing the IP Adapter scale between 0.0 and 1.0 changes the model output."""
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        model = self.model_class(**init_dict).to(torch_device)

        ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
        model._load_ip_adapter_weights([ip_adapter_state_dict])

        inputs_dict_with_adapter = self.modify_inputs_for_ip_adapter(model, inputs_dict.copy())

        # Test scale = 0.0 (no effect)
        model.set_ip_adapter_scale(0.0)
        torch.manual_seed(0)
        output_scale_zero = model(**inputs_dict_with_adapter, return_dict=False)[0]

        # Test scale = 1.0 (full effect)
        model.set_ip_adapter_scale(1.0)
        torch.manual_seed(0)
        output_scale_one = model(**inputs_dict_with_adapter, return_dict=False)[0]

        # Outputs should differ with different scales
        assert not torch.allclose(output_scale_zero, output_scale_one, atol=1e-4, rtol=1e-4), (
            "Output should differ with different IP Adapter scales"
        )

    @pytest.mark.skip(
        reason="Unloading IP Adapter is not defined at the model level. Enable this test after refactoring"
    )
    def test_unload_ip_adapter(self):
        """Unloading the IP Adapter restores the original attention processor types."""
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict).to(torch_device)

        # Save original processors (by type name, so the comparison survives re-instantiation)
        original_processors = {k: type(v).__name__ for k, v in model.attn_processors.items()}

        # Create and load IP adapter
        ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
        model._load_ip_adapter_weights([ip_adapter_state_dict])

        assert check_if_ip_adapter_correctly_set(model, self.ip_adapter_processor_cls), "IP Adapter should be set"

        # Unload IP adapter
        model.unload_ip_adapter()

        assert not check_if_ip_adapter_correctly_set(model, self.ip_adapter_processor_cls), (
            "IP Adapter should be unloaded"
        )

        # Verify processors are restored
        current_processors = {k: type(v).__name__ for k, v in model.attn_processors.items()}
        assert original_processors == current_processors, "Processors should be restored after unload"
|