1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
Files
diffusers/tests/models/testing_utils/attention.py
2025-12-11 11:04:47 +05:30

264 lines
10 KiB
Python

# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
import torch
import torch.multiprocessing as mp
from diffusers.models._modeling_parallel import ContextParallelConfig
from diffusers.models.attention import AttentionModuleMixin
from diffusers.models.attention_processor import (
AttnProcessor,
)
from ...testing_utils import is_attention, is_context_parallel, require_torch_multi_accelerator, torch_device
@is_attention
class AttentionTesterMixin:
    """
    Mixin class for testing attention processor and module functionality on models.

    Tests functionality from AttentionModuleMixin including:
    - Attention processor management (set/get, per module and via a processor dict)
    - QKV projection fusion/unfusion

    Expected class attributes to be set by subclasses:
    - model_class: The model class to test
    - base_precision: Tolerance for floating point comparisons (default: 1e-3)
    - uses_custom_attn_processor: Whether model uses custom attention processors
      (not defined or read by the tests in this mixin)

    Expected methods to be implemented by subclasses:
    - get_init_dict(): Returns dict of arguments to initialize the model
    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass

    Pytest mark: attention
    Use `pytest -m "not attention"` to skip these tests
    """

    base_precision = 1e-3

    @staticmethod
    def _primary_output(output):
        """
        Return the first element of a model output so it can be compared with `torch.allclose`.

        Dict outputs are converted with their `to_tuple()` method, tuple/list outputs (e.g. from
        `return_dict=False`) are indexed directly, and anything else is returned unchanged.
        """
        if isinstance(output, dict):
            return output.to_tuple()[0]
        if isinstance(output, (tuple, list)):
            return output[0]
        return output

    def test_fuse_unfuse_qkv_projections(self):
        """Fusing and then unfusing QKV projections must not change the model output."""
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        if not hasattr(model, "fuse_qkv_projections"):
            pytest.skip("Model does not support QKV projection fusion.")

        # Reference output before any fusion.
        with torch.no_grad():
            output_before_fusion = self._primary_output(model(**inputs_dict))

        model.fuse_qkv_projections()

        # Fusion is detected by the fused layers it creates; every module that has one must also
        # report the fused state through its `fused_projections` flag.
        fused_modules = [
            module
            for module in model.modules()
            if isinstance(module, AttentionModuleMixin) and (hasattr(module, "to_qkv") or hasattr(module, "to_kv"))
        ]
        if not fused_modules:
            # Nothing was fused, so there is nothing to compare; report that instead of passing vacuously.
            pytest.skip("fuse_qkv_projections() did not create any fused projection layers.")
        for module in fused_modules:
            assert module.fused_projections, "fused_projections flag should be True"

        with torch.no_grad():
            output_after_fusion = self._primary_output(model(**inputs_dict))
        assert torch.allclose(output_before_fusion, output_after_fusion, atol=self.base_precision), (
            "Output should not change after fusing projections"
        )

        model.unfuse_qkv_projections()

        # Unfusing must remove the fused layers and clear the flag on every attention module.
        for module in model.modules():
            if isinstance(module, AttentionModuleMixin):
                assert not hasattr(module, "to_qkv"), "to_qkv should be removed after unfusing"
                assert not hasattr(module, "to_kv"), "to_kv should be removed after unfusing"
                assert not module.fused_projections, "fused_projections flag should be False"

        with torch.no_grad():
            output_after_unfusion = self._primary_output(model(**inputs_dict))
        assert torch.allclose(output_before_fusion, output_after_unfusion, atol=self.base_precision), (
            "Output should match original after unfusing projections"
        )

    def test_get_set_processor(self):
        """Every attention module exposes a processor and accepts a replacement via set_processor."""
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        if not hasattr(model, "attn_processors"):
            pytest.skip("Model does not have attention processors.")

        processors = model.attn_processors
        assert isinstance(processors, dict), "attn_processors should return a dict"
        assert len(processors) > 0, "Model should have at least one attention processor"

        for module in model.modules():
            if isinstance(module, AttentionModuleMixin):
                processor = module.get_processor()
                assert processor is not None, "get_processor should return a processor"

                # Setting a new processor must be reflected by get_processor (same object).
                new_processor = AttnProcessor()
                module.set_processor(new_processor)
                retrieved_processor = module.get_processor()
                assert retrieved_processor is new_processor, "Retrieved processor should be the same as the one set"

    def test_attention_processor_dict(self):
        """set_attn_processor accepts a dict keyed like attn_processors and applies every entry."""
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        if not hasattr(model, "set_attn_processor"):
            pytest.skip("Model does not support setting attention processors.")

        current_processors = model.attn_processors
        new_processors = {key: AttnProcessor() for key in current_processors}
        model.set_attn_processor(new_processors)

        updated_processors = model.attn_processors
        for key in current_processors:
            # Exact type check: a subclass of AttnProcessor would not be the processor we set.
            assert type(updated_processors[key]) is AttnProcessor, f"Processor {key} should be AttnProcessor"

    def test_attention_processor_count_mismatch_raises_error(self):
        """set_attn_processor rejects a dict whose size differs from the number of processors."""
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        if not hasattr(model, "set_attn_processor"):
            pytest.skip("Model does not support setting attention processors.")

        current_processors = model.attn_processors
        # The mismatch is built from a single entry. With exactly one processor that dict would have the
        # correct count (so no count error is expected), and with none there is no key to reuse.
        if len(current_processors) < 2:
            pytest.skip("Model needs at least two attention processors to test a count mismatch.")

        first_key = next(iter(current_processors))
        wrong_processors = {first_key: AttnProcessor()}

        with pytest.raises(ValueError) as exc_info:
            model.set_attn_processor(wrong_processors)
        assert "number of processors" in str(exc_info.value).lower(), "Error should mention processor count mismatch"
def _context_parallel_worker(rank, world_size, model_class, init_dict, cp_dict, inputs_dict, result_queue):
    """
    Run one context-parallel forward pass on a single rank (target function for `mp.spawn`).

    Args:
        rank: Index of this process; also used as its CUDA device index.
        world_size: Total number of spawned processes.
        model_class: Model class instantiated with `init_dict`.
        init_dict: Keyword arguments for `model_class`.
        cp_dict: Keyword arguments for `ContextParallelConfig`.
        inputs_dict: Forward-pass inputs; tensor values are moved to this rank's device.
        result_queue: Receives exactly one `(status, payload)` tuple, from rank 0 only:
            `("success", output.shape)` or `("error", formatted traceback)`.

    Exceptions on ranks other than 0 are re-raised so `mp.spawn` reports them (and stops the
    remaining workers) instead of the failure being silently dropped.
    """
    import traceback

    try:
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "12355"
        # Bind this process to its GPU before creating the NCCL group so NCCL uses the right device.
        torch.cuda.set_device(rank)
        torch.distributed.init_process_group(
            backend="nccl",
            init_method="env://",
            world_size=world_size,
            rank=rank,
        )
        device = torch.device(f"cuda:{rank}")

        model = model_class(**init_dict)
        model.to(device)
        model.eval()

        inputs_on_device = {
            key: value.to(device) if isinstance(value, torch.Tensor) else value for key, value in inputs_dict.items()
        }

        model.enable_parallelism(config=ContextParallelConfig(**cp_dict))

        with torch.no_grad():
            output = model(**inputs_on_device)
            if isinstance(output, dict):
                output = output.to_tuple()[0]

        if rank == 0:
            result_queue.put(("success", output.shape))
    except Exception:
        if rank == 0:
            result_queue.put(("error", traceback.format_exc()))
        else:
            # Only rank 0 reports through the queue; other ranks must fail loudly.
            raise
    finally:
        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()
@is_context_parallel
@require_torch_multi_accelerator
class ContextParallelTesterMixin:
    """
    Mixin class for testing context-parallel inference on models.

    Spawns two worker processes (one per CUDA device) via `_context_parallel_worker`, each running a
    forward pass with a `ContextParallelConfig` built from the parametrized degree, and checks that
    rank 0 reported success.

    Expected class attributes to be set by subclasses:
    - model_class: The model class to test; must define a non-None `_cp_plan`
    - base_precision: Tolerance for floating point comparisons (default: 1e-3; not read by this mixin)

    Expected methods to be implemented by subclasses:
    - get_init_dict(): Returns dict of arguments to initialize the model
    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
    """

    base_precision = 1e-3

    @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
    def test_context_parallel_inference(self, cp_type):
        import queue

        if not torch.distributed.is_available():
            pytest.skip("torch.distributed is not available.")
        # The worker hard-codes the NCCL backend.
        if not torch.distributed.is_nccl_available():
            pytest.skip("Context parallel tests require the NCCL backend.")
        if not torch.cuda.is_available() or torch.cuda.device_count() < 2:
            pytest.skip("Context parallel requires at least 2 CUDA devices.")
        if getattr(self.model_class, "_cp_plan", None) is None:
            pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")

        world_size = 2
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        cp_dict = {cp_type: world_size}

        # The queue must come from a spawn context to be shareable with processes started by mp.spawn.
        ctx = mp.get_context("spawn")
        result_queue = ctx.Queue()

        mp.spawn(
            _context_parallel_worker,
            args=(world_size, self.model_class, init_dict, cp_dict, inputs_dict, result_queue),
            nprocs=world_size,
            join=True,
        )

        # All workers have already been joined, so the timeout only guards against rank 0 exiting
        # without putting a result on the queue.
        try:
            status, result = result_queue.get(timeout=60)
        except queue.Empty:
            pytest.fail("Context parallel workers exited without reporting a result from rank 0.")
        assert status == "success", f"Context parallel inference failed: {result}"