# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc

import pytest
import torch

from diffusers.models.attention import AttentionModuleMixin
from diffusers.models.attention_processor import (
    AttnProcessor,
)

from ...testing_utils import (
    assert_tensors_close,
    backend_empty_cache,
    is_attention,
    torch_device,
)


@is_attention
class AttentionTesterMixin:
    """
    Mixin class for testing attention processor and module functionality on models.

    Tests functionality from AttentionModuleMixin, including:
    - Attention processor management (get/set)
    - QKV projection fusion/unfusion
    - Attention backends (xFormers, NPU, etc.)

    Expected attributes from the config mixin:
    - model_class: The model class to test

    Expected methods from the config mixin:
    - get_init_dict(): Returns a dict of arguments used to initialize the model
    - get_dummy_inputs(): Returns a dict of inputs to pass to the model's forward pass

    A minimal usage sketch appears in the comment block below this docstring.

    Pytest mark: attention
    Use `pytest -m "not attention"` to skip these tests.
    """

    def setup_method(self):
        gc.collect()
        backend_empty_cache(torch_device)

    def teardown_method(self):
        gc.collect()
        backend_empty_cache(torch_device)

    @torch.no_grad()
    def test_fuse_unfuse_qkv_projections(self, atol=1e-3, rtol=0):
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        if not hasattr(model, "fuse_qkv_projections"):
            pytest.skip("Model does not support QKV projection fusion.")

        output_before_fusion = model(**inputs_dict, return_dict=False)[0]

        model.fuse_qkv_projections()
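
        # In diffusers, fusing concatenates the separate to_q/to_k/to_v linear
        # projections (to_k/to_v for cross-attention) into a single layer,
        # exposed on the module as `to_qkv` (or `to_kv`), and sets the
        # `fused_projections` flag; the scan below checks for exactly that.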
        has_fused_projections = False
        for module in model.modules():
            if isinstance(module, AttentionModuleMixin):
                if hasattr(module, "to_qkv") or hasattr(module, "to_kv"):
                    has_fused_projections = True
                    assert module.fused_projections, "fused_projections flag should be True"
                    break

        if has_fused_projections:
            output_after_fusion = model(**inputs_dict, return_dict=False)[0]

            assert_tensors_close(
                output_before_fusion,
                output_after_fusion,
                atol=atol,
                rtol=rtol,
                msg="Output should not change after fusing projections",
            )

            model.unfuse_qkv_projections()

            for module in model.modules():
                if isinstance(module, AttentionModuleMixin):
                    assert not hasattr(module, "to_qkv"), "to_qkv should be removed after unfusing"
                    assert not hasattr(module, "to_kv"), "to_kv should be removed after unfusing"
                    assert not module.fused_projections, "fused_projections flag should be False"

            output_after_unfusion = model(**inputs_dict, return_dict=False)[0]

            assert_tensors_close(
                output_before_fusion,
                output_after_unfusion,
                atol=atol,
                rtol=rtol,
                msg="Output should match original after unfusing projections",
            )

    def test_get_set_processor(self):
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        # Check if model has attention processors
        if not hasattr(model, "attn_processors"):
            pytest.skip("Model does not have attention processors.")

        # Test getting processors
        processors = model.attn_processors
        assert isinstance(processors, dict), "attn_processors should return a dict"
        assert len(processors) > 0, "Model should have at least one attention processor"

        # Test that all processors can be retrieved via get_processor
        for module in model.modules():
            if isinstance(module, AttentionModuleMixin):
                processor = module.get_processor()
                assert processor is not None, "get_processor should return a processor"

                # Test setting a new processor
                new_processor = AttnProcessor()
                module.set_processor(new_processor)
                retrieved_processor = module.get_processor()
                assert retrieved_processor is new_processor, "Retrieved processor should be the same as the one set"
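
    # A hypothetical illustration (not used by these tests): set_processor
    # accepts any object implementing the processor `__call__` interface, so a
    # thin wrapper around AttnProcessor that counts invocations would also work:
    #
    #     class CountingProcessor(AttnProcessor):
    #         def __init__(self):
    #             super().__init__()
    #             self.calls = 0
    #
    #         def __call__(self, *args, **kwargs):
    #             self.calls += 1
    #             return super().__call__(*args, **kwargs)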

    def test_attention_processor_dict(self):
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        if not hasattr(model, "set_attn_processor"):
            pytest.skip("Model does not support setting attention processors.")

        # Get current processors
        current_processors = model.attn_processors

        # Create a dict of new processors
        new_processors = {key: AttnProcessor() for key in current_processors.keys()}
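
        # The keys are dotted module paths ending in ".processor" (for example
        # "blocks.0.attn.processor"); set_attn_processor also accepts a single
        # processor instance, which it then applies to every attention module.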

        # Set processors using dict
        model.set_attn_processor(new_processors)

        # Verify all processors were set
        updated_processors = model.attn_processors
        for key in current_processors.keys():
            assert type(updated_processors[key]) == AttnProcessor, f"Processor {key} should be AttnProcessor"

    def test_attention_processor_count_mismatch_raises_error(self):
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        if not hasattr(model, "set_attn_processor"):
            pytest.skip("Model does not support setting attention processors.")

        # Get current processors
        current_processors = model.attn_processors

        # Guard: with a single attention processor, a one-entry dict is not a mismatch
        if len(current_processors) <= 1:
            pytest.skip("Model has only one attention processor; a count mismatch cannot be constructed.")

        # Create a dict with the wrong number of processors
        wrong_processors = {list(current_processors.keys())[0]: AttnProcessor()}

        # Verify the error is raised
        with pytest.raises(ValueError) as exc_info:
            model.set_attn_processor(wrong_processors)

        assert "number of processors" in str(exc_info.value).lower(), "Error should mention processor count mismatch"