Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
Commit: update
@@ -1,8 +1,8 @@
from .attention import AttentionTesterMixin, ContextParallelTesterMixin
from .common import ModelTesterMixin
from .common import BaseModelTesterConfig, ModelTesterMixin
from .compile import TorchCompileTesterMixin
from .ip_adapter import IPAdapterTesterMixin
from .lora import LoraTesterMixin
from .lora import LoraHotSwappingForModelTesterMixin, LoraTesterMixin
from .memory import CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin, MemoryTesterMixin
from .quantization import (
BitsAndBytesTesterMixin,
@@ -17,14 +17,16 @@ from .training import TrainingTesterMixin

__all__ = [
"ContextParallelTesterMixin",
"AttentionTesterMixin",
"BaseModelTesterConfig",
"BitsAndBytesTesterMixin",
"ContextParallelTesterMixin",
"CPUOffloadTesterMixin",
"GGUFTesterMixin",
"GroupOffloadTesterMixin",
"IPAdapterTesterMixin",
"LayerwiseCastingTesterMixin",
"LoraHotSwappingForModelTesterMixin",
"LoraTesterMixin",
"MemoryTesterMixin",
"ModelOptTesterMixin",
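Editor's note (not part of the diff): the mixins exported above are meant to be composed with a per-model configuration class. A minimal sketch of how a downstream test module might combine them; the import path, model choice, and sizes here are assumptions for illustration only.

import torch

from diffusers import UNet2DModel

# Hypothetical import path; adjust to wherever this testing_utils package lives.
from ..testing_utils import BaseModelTesterConfig, LoraTesterMixin, MemoryTesterMixin, ModelTesterMixin


class UNet2DTestConfig(BaseModelTesterConfig):
    # Class-attribute style, as in the ModelTesterMixin docstring example below.
    model_class = UNet2DModel

    def get_init_dict(self):
        # Keep the model tiny so the shared tests stay fast.
        return {"sample_size": 32, "in_channels": 3, "out_channels": 3}

    def get_dummy_inputs(self):
        return {
            "sample": torch.randn(1, 3, 32, 32),
            "timestep": torch.tensor([1]),
        }


class TestUNet2DModel(UNet2DTestConfig, ModelTesterMixin, LoraTesterMixin, MemoryTesterMixin):
    pass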
@@ -25,7 +25,13 @@ from diffusers.models.attention_processor import (
AttnProcessor,
)

from ...testing_utils import is_attention, is_context_parallel, require_torch_multi_accelerator, torch_device
from ...testing_utils import (
assert_tensors_close,
is_attention,
is_context_parallel,
require_torch_multi_accelerator,
torch_device,
)

@is_attention
@@ -89,8 +95,12 @@ class AttentionTesterMixin:
output_after_fusion = output_after_fusion.to_tuple()[0]

# Verify outputs match
assert torch.allclose(output_before_fusion, output_after_fusion, atol=self.base_precision), (
"Output should not change after fusing projections"
assert_tensors_close(
output_before_fusion,
output_after_fusion,
atol=self.base_precision,
rtol=0,
msg="Output should not change after fusing projections",
)

# Unfuse projections
@@ -110,8 +120,12 @@ class AttentionTesterMixin:
output_after_unfusion = output_after_unfusion.to_tuple()[0]

# Verify outputs still match
assert torch.allclose(output_before_fusion, output_after_unfusion, atol=self.base_precision), (
"Output should match original after unfusing projections"
assert_tensors_close(
output_before_fusion,
output_after_unfusion,
atol=self.base_precision,
rtol=0,
msg="Output should match original after unfusing projections",
)

def test_get_set_processor(self):
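Editor's note (not part of the diff): the hunks above migrate bare torch.allclose assertions to an assert_tensors_close helper imported from testing_utils. Its implementation is not shown in this diff; a plausible sketch, assuming it only wraps torch.allclose with a readable failure message:

import torch


def assert_tensors_close(actual, expected, *, atol=1e-5, rtol=0.0, msg=""):
    # Sketch: fail with the maximum absolute difference when the tensors diverge.
    if not torch.allclose(actual, expected, atol=atol, rtol=rtol):
        max_diff = (actual - expected).abs().max().item()
        raise AssertionError(f"{msg} (max abs diff: {max_diff}, atol={atol}, rtol={rtol})")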
@@ -238,9 +252,6 @@ class ContextParallelTesterMixin:
if not torch.distributed.is_available():
pytest.skip("torch.distributed is not available.")

if not torch.cuda.is_available() or torch.cuda.device_count() < 2:
pytest.skip("Context parallel requires at least 2 CUDA devices.")

if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
@@ -15,8 +15,8 @@

import json
import os
import tempfile
from collections import defaultdict
from typing import Any, Dict, Optional, Type

import pytest
import torch
@@ -26,7 +26,7 @@ from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, d
from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, _add_variant, logging
from diffusers.utils.testing_utils import require_accelerator, require_torch_multi_accelerator

from ...testing_utils import CaptureLogger, torch_device
from ...testing_utils import assert_tensors_close, torch_device

def named_persistent_module_tensors(
@@ -130,40 +130,144 @@ def check_device_map_is_respected(model, device_map):
)

class BaseModelTesterConfig:
"""
Base class defining the configuration interface for model testing.

This class defines the contract that all model test classes must implement.
It provides a consistent interface for accessing model configuration, initialization
parameters, and test inputs across all testing mixins.

Required properties (must be implemented by subclasses):
- model_class: The model class to test

Optional properties (can be overridden, have sensible defaults):
- pretrained_model_name_or_path: Hub repository ID for pretrained model (default: None)
- pretrained_model_kwargs: Additional kwargs for from_pretrained (default: {})
- output_shape: Expected output shape for output validation tests (default: None)
- base_precision: Default tolerance for floating point comparisons (default: 1e-3)
- model_split_percents: Percentages for model parallelism tests (default: [0.5, 0.7])

Required methods (must be implemented by subclasses):
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass

Example usage:
class MyModelTestConfig(BaseModelTesterConfig):
@property
def model_class(self):
return MyModel

@property
def pretrained_model_name_or_path(self):
return "org/my-model"

@property
def output_shape(self):
return (1, 3, 32, 32)

def get_init_dict(self):
return {"in_channels": 3, "out_channels": 3}

def get_dummy_inputs(self):
return {"sample": torch.randn(1, 3, 32, 32, device=torch_device)}

class TestMyModel(MyModelTestConfig, ModelTesterMixin, QuantizationTesterMixin):
pass
"""

# ==================== Required Properties ====================

@property
def model_class(self) -> Type[nn.Module]:
"""The model class to test. Must be implemented by subclasses."""
raise NotImplementedError("Subclasses must implement the `model_class` property.")

# ==================== Optional Properties ====================

@property
def pretrained_model_name_or_path(self) -> Optional[str]:
"""Hub repository ID for the pretrained model (used for quantization and hub tests)."""
return None

@property
def pretrained_model_kwargs(self) -> Dict[str, Any]:
"""Additional kwargs to pass to from_pretrained (e.g., subfolder, variant)."""
return {}

@property
def output_shape(self) -> Optional[tuple]:
"""Expected output shape for output validation tests."""
return None

@property
def model_split_percents(self) -> list:
"""Percentages for model parallelism tests."""
return [0.5, 0.7]

# ==================== Required Methods ====================

def get_init_dict(self) -> Dict[str, Any]:
"""
Returns dict of arguments to initialize the model.

Returns:
Dict[str, Any]: Initialization arguments for the model constructor.

Example:
return {
"in_channels": 3,
"out_channels": 3,
"sample_size": 32,
}
"""
raise NotImplementedError("Subclasses must implement `get_init_dict()`.")

def get_dummy_inputs(self) -> Dict[str, Any]:
"""
Returns dict of inputs to pass to the model forward pass.

Returns:
Dict[str, Any]: Input tensors/values for model.forward().

Example:
return {
"sample": torch.randn(1, 3, 32, 32, device=torch_device),
"timestep": torch.tensor([1], device=torch_device),
}
"""
raise NotImplementedError("Subclasses must implement `get_dummy_inputs()`.")
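Editor's note (not part of the diff): to make the contract above concrete, here is a minimal, hypothetical mixin-style test written only against the documented interface (model_class, get_init_dict, get_dummy_inputs, output_shape). The class name is an assumption for illustration.

import torch


class OutputShapeTesterMixin:
    def test_forward_output_shape(self):
        # Relies purely on the BaseModelTesterConfig contract documented above.
        model = self.model_class(**self.get_init_dict())
        model.eval()
        with torch.no_grad():
            output = model(**self.get_dummy_inputs())
        if isinstance(output, dict):
            # diffusers ModelOutput objects are dict-like; same pattern as the mixins below.
            output = output.to_tuple()[0]
        if self.output_shape is not None:
            assert tuple(output.shape) == tuple(self.output_shape)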
class ModelTesterMixin:
"""
Base mixin class for model testing with common test methods.

Expected class attributes to be set by subclasses:
This mixin expects the test class to also inherit from BaseModelTesterConfig
(or implement its interface) which provides:
- model_class: The model class to test
- main_input_name: Name of the main input tensor (e.g., "sample", "hidden_states")
- base_precision: Default tolerance for floating point comparisons (default: 1e-3)

Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass

Example:
class MyModelTestConfig(BaseModelTesterConfig):
model_class = MyModel
def get_init_dict(self): ...
def get_dummy_inputs(self): ...

class TestMyModel(MyModelTestConfig, ModelTesterMixin):
pass
"""

model_class = None
base_precision = 1e-3
model_split_percents = [0.5, 0.7]

def get_init_dict(self):
raise NotImplementedError("get_init_dict must be implemented by subclasses. It should return init_dict.")

def get_dummy_inputs(self):
raise NotImplementedError("get_dummy_inputs must be implemented by subclasses. It should return inputs_dict.")

def test_from_save_pretrained(self, expected_max_diff=5e-5):
def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=0):
torch.manual_seed(0)
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()

with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
new_model = self.model_class.from_pretrained(tmpdirname)
new_model.to(torch_device)
model.save_pretrained(tmp_path)
new_model = self.model_class.from_pretrained(tmp_path)
new_model.to(torch_device)

# check if all parameters shape are the same
for param_name in model.state_dict().keys():
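Editor's note (not part of the diff): this hunk and the ones below replace tempfile.TemporaryDirectory with pytest's built-in tmp_path fixture, which hands each test a unique, automatically cleaned pathlib.Path. A small standalone illustration of the pattern:

import json


def test_round_trip(tmp_path):
    # tmp_path is a pathlib.Path unique to this test invocation; no manual cleanup needed.
    config_file = tmp_path / "config.json"
    config_file.write_text(json.dumps({"in_channels": 3}))
    assert json.loads(config_file.read_text()) == {"in_channels": 3}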
@@ -184,28 +288,24 @@ class ModelTesterMixin:
if isinstance(new_image, dict):
new_image = new_image.to_tuple()[0]

max_diff = (image - new_image).abs().max().item()
assert max_diff <= expected_max_diff, (
f"Models give different forward passes. Max diff: {max_diff}, expected: {expected_max_diff}"
)
assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")

def test_from_save_pretrained_variant(self, expected_max_diff=5e-5):
def test_from_save_pretrained_variant(self, tmp_path, atol=5e-5, rtol=0):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()

with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, variant="fp16")
new_model = self.model_class.from_pretrained(tmpdirname, variant="fp16")
model.save_pretrained(tmp_path, variant="fp16")
new_model = self.model_class.from_pretrained(tmp_path, variant="fp16")

# non-variant cannot be loaded
with pytest.raises(OSError) as exc_info:
self.model_class.from_pretrained(tmpdirname)
# non-variant cannot be loaded
with pytest.raises(OSError) as exc_info:
self.model_class.from_pretrained(tmp_path)

# make sure that error message states what keys are missing
assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(exc_info.value)
# make sure that error message states what keys are missing
assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(exc_info.value)

new_model.to(torch_device)
new_model.to(torch_device)

with torch.no_grad():
image = model(**self.get_dummy_inputs())
@@ -217,35 +317,27 @@ class ModelTesterMixin:
if isinstance(new_image, dict):
new_image = new_image.to_tuple()[0]

max_diff = (image - new_image).abs().max().item()
assert max_diff <= expected_max_diff, (
f"Models give different forward passes. Max diff: {max_diff}, expected: {expected_max_diff}"
)
assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
def test_from_save_pretrained_dtype(self):
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"])
def test_from_save_pretrained_dtype(self, tmp_path, dtype):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()

for dtype in [torch.float32, torch.float16, torch.bfloat16]:
if torch_device == "mps" and dtype == torch.bfloat16:
continue
with tempfile.TemporaryDirectory() as tmpdirname:
model.to(dtype)
model.save_pretrained(tmpdirname)
new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True, torch_dtype=dtype)
assert new_model.dtype == dtype
if (
hasattr(self.model_class, "_keep_in_fp32_modules")
and self.model_class._keep_in_fp32_modules is None
):
# When loading without accelerate dtype == torch.float32 if _keep_in_fp32_modules is not None
new_model = self.model_class.from_pretrained(
tmpdirname, low_cpu_mem_usage=False, torch_dtype=dtype
)
assert new_model.dtype == dtype
if torch_device == "mps" and dtype == torch.bfloat16:
pytest.skip(reason=f"{dtype} is not supported on {torch_device}")

def test_determinism(self, expected_max_diff=1e-5):
model.to(dtype)
model.save_pretrained(tmp_path)
new_model = self.model_class.from_pretrained(tmp_path, low_cpu_mem_usage=True, torch_dtype=dtype)
assert new_model.dtype == dtype
if hasattr(self.model_class, "_keep_in_fp32_modules") and self.model_class._keep_in_fp32_modules is None:
# When loading without accelerate dtype == torch.float32 if _keep_in_fp32_modules is not None
new_model = self.model_class.from_pretrained(tmp_path, low_cpu_mem_usage=False, torch_dtype=dtype)
assert new_model.dtype == dtype
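Editor's note (not part of the diff): the dtype loop is replaced by @pytest.mark.parametrize with explicit ids, so each dtype becomes its own test node (e.g. test_from_save_pretrained_dtype[fp16]) and an unsupported dtype can be skipped individually. A minimal standalone illustration:

import pytest
import torch


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"])
def test_zeros_dtype(dtype):
    # Collected as test_zeros_dtype[fp32], test_zeros_dtype[fp16], test_zeros_dtype[bf16].
    assert torch.zeros(2, dtype=dtype).dtype == dtype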
def test_determinism(self, atol=1e-5, rtol=0):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
@@ -259,18 +351,15 @@ class ModelTesterMixin:
if isinstance(second, dict):
second = second.to_tuple()[0]

# Remove NaN values and compute max difference
# Filter out NaN values before comparison
first_flat = first.flatten()
second_flat = second.flatten()

# Filter out NaN values
mask = ~(torch.isnan(first_flat) | torch.isnan(second_flat))
first_filtered = first_flat[mask]
second_filtered = second_flat[mask]

max_diff = torch.abs(first_filtered - second_filtered).max().item()
assert max_diff <= expected_max_diff, (
f"Model outputs are not deterministic. Max diff: {max_diff}, expected: {expected_max_diff}"
assert_tensors_close(
first_filtered, second_filtered, atol=atol, rtol=rtol, msg="Model outputs are not deterministic"
)
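Editor's note (not part of the diff): the determinism check above masks NaNs out of both outputs before comparing. Written as a standalone helper, the same logic is just:

import torch


def filter_common_nans(first, second):
    # Drop positions where either tensor is NaN, then compare the rest.
    first_flat, second_flat = first.flatten(), second.flatten()
    mask = ~(torch.isnan(first_flat) | torch.isnan(second_flat))
    return first_flat[mask], second_flat[mask]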
def test_output(self, expected_output_shape=None):
@@ -310,13 +399,12 @@ class ModelTesterMixin:
elif tuple_object is None:
return
else:
assert torch.allclose(
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
), (
"Tuple and dict output are not equal. Difference:"
f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
assert_tensors_close(
set_nan_tensor_to_zero(tuple_object),
set_nan_tensor_to_zero(dict_object),
atol=1e-5,
rtol=0,
msg="Tuple and dict output are not equal",
)
model = self.model_class(**self.get_init_dict())
@@ -329,7 +417,7 @@ class ModelTesterMixin:

recursive_check(outputs_tuple, outputs_dict)

def test_getattr_is_correct(self):
def test_getattr_is_correct(self, caplog):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
@@ -337,28 +425,26 @@ class ModelTesterMixin:
model.dummy_attribute = 5
model.register_to_config(test_attribute=5)

logger = logging.get_logger("diffusers.models.modeling_utils")
# 30 for warning
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
logger_name = "diffusers.models.modeling_utils"
with caplog.at_level(logging.WARNING, logger=logger_name):
caplog.clear()
assert hasattr(model, "dummy_attribute")
assert getattr(model, "dummy_attribute") == 5
assert model.dummy_attribute == 5

# no warning should be thrown
assert cap_logger.out == ""
assert caplog.text == ""

logger = logging.get_logger("diffusers.models.modeling_utils")
# 30 for warning
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
with caplog.at_level(logging.WARNING, logger=logger_name):
caplog.clear()
assert hasattr(model, "save_pretrained")
fn = model.save_pretrained
fn_1 = getattr(model, "save_pretrained")

assert fn == fn_1

# no warning should be thrown
assert cap_logger.out == ""
assert caplog.text == ""

# warning should be thrown for config attributes accessed directly
with pytest.warns(FutureWarning):
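Editor's note (not part of the diff): test_getattr_is_correct now relies on pytest's caplog fixture instead of the CaptureLogger context manager; caplog.at_level temporarily sets the level of the named logger and records emitted messages. A minimal standalone example:

import logging


def test_no_warning_emitted(caplog):
    logger = logging.getLogger("diffusers.models.modeling_utils")
    with caplog.at_level(logging.WARNING, logger="diffusers.models.modeling_utils"):
        caplog.clear()
        logger.debug("below the captured level")  # not recorded at WARNING
    assert caplog.text == ""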
@@ -399,32 +485,34 @@ class ModelTesterMixin:
torch_device not in ["cuda", "xpu"],
reason="float16 and bfloat16 can only be use for inference with an accelerator",
)
def test_from_save_pretrained_float16_bfloat16(self):
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
fp32_modules = model._keep_in_fp32_modules

with tempfile.TemporaryDirectory() as tmp_dir:
for torch_dtype in [torch.bfloat16, torch.float16]:
model.to(torch_dtype).save_pretrained(tmp_dir)
model_loaded = self.model_class.from_pretrained(tmp_dir, torch_dtype=torch_dtype).to(torch_device)
model.to(dtype).save_pretrained(tmp_path)
model_loaded = self.model_class.from_pretrained(tmp_path, torch_dtype=dtype).to(torch_device)

for name, param in model_loaded.named_parameters():
if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules):
assert param.data.dtype == torch.float32
else:
assert param.data.dtype == torch_dtype
for name, param in model_loaded.named_parameters():
if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules):
assert param.data.dtype == torch.float32
else:
assert param.data.dtype == dtype

with torch.no_grad():
output = model(**self.get_dummy_inputs())
output_loaded = model_loaded(**self.get_dummy_inputs())
with torch.no_grad():
output = model(**self.get_dummy_inputs())
if isinstance(output, dict):
output = output.to_tuple()[0]

assert torch.allclose(output, output_loaded, atol=1e-4), (
f"Loaded model output differs for {torch_dtype}"
)
output_loaded = model_loaded(**self.get_dummy_inputs())
if isinstance(output_loaded, dict):
output_loaded = output_loaded.to_tuple()[0]

assert_tensors_close(output, output_loaded, atol=1e-4, rtol=0, msg=f"Loaded model output differs for {dtype}")

@require_accelerator
def test_sharded_checkpoints(self):
def test_sharded_checkpoints(self, tmp_path):
torch.manual_seed(0)
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
@@ -435,30 +523,30 @@ class ModelTesterMixin:

model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
assert os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist"

# Check if the right number of shards exists
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
assert actual_num_shards == expected_num_shards, (
f"Expected {expected_num_shards} shards, got {actual_num_shards}"
)
model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB")
assert os.path.exists(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist"

new_model = self.model_class.from_pretrained(tmp_dir).eval()
new_model = new_model.to(torch_device)
# Check if the right number of shards exists
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME))
actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")])
assert actual_num_shards == expected_num_shards, (
f"Expected {expected_num_shards} shards, got {actual_num_shards}"
)

torch.manual_seed(0)
inputs_dict_new = self.get_dummy_inputs()
new_output = new_model(**inputs_dict_new)
new_model = self.model_class.from_pretrained(tmp_path).eval()
new_model = new_model.to(torch_device)

assert torch.allclose(base_output[0], new_output[0], atol=1e-5), (
"Output should match after sharded save/load"
)
torch.manual_seed(0)
inputs_dict_new = self.get_dummy_inputs()
new_output = new_model(**inputs_dict_new)

assert_tensors_close(
base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match after sharded save/load"
)
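Editor's note (not part of the diff): calculate_expected_num_shards is referenced above but its implementation is not shown in this diff. Assuming it reads the safetensors index file, a plausible sketch counts the distinct shard files listed in the weight_map:

import json


def calculate_expected_num_shards(index_file_path):
    # The index maps each parameter name to the shard file that stores it.
    with open(index_file_path) as f:
        index = json.load(f)
    return len(set(index["weight_map"].values()))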
@require_accelerator
def test_sharded_checkpoints_with_variant(self):
def test_sharded_checkpoints_with_variant(self, tmp_path):
torch.manual_seed(0)
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
@@ -470,35 +558,33 @@ class ModelTesterMixin:
model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small
variant = "fp16"
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB", variant=variant)

index_filename = _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
assert os.path.exists(os.path.join(tmp_dir, index_filename)), (
f"Variant index file {index_filename} should exist"
)
model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB", variant=variant)

# Check if the right number of shards exists
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_dir, index_filename))
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
assert actual_num_shards == expected_num_shards, (
f"Expected {expected_num_shards} shards, got {actual_num_shards}"
)
index_filename = _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
assert os.path.exists(os.path.join(tmp_path, index_filename)), (
f"Variant index file {index_filename} should exist"
)

new_model = self.model_class.from_pretrained(tmp_dir, variant=variant).eval()
new_model = new_model.to(torch_device)
# Check if the right number of shards exists
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_path, index_filename))
actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")])
assert actual_num_shards == expected_num_shards, (
f"Expected {expected_num_shards} shards, got {actual_num_shards}"
)

torch.manual_seed(0)
inputs_dict_new = self.get_dummy_inputs()
new_output = new_model(**inputs_dict_new)
new_model = self.model_class.from_pretrained(tmp_path, variant=variant).eval()
new_model = new_model.to(torch_device)

assert torch.allclose(base_output[0], new_output[0], atol=1e-5), (
"Output should match after variant sharded save/load"
)
torch.manual_seed(0)
inputs_dict_new = self.get_dummy_inputs()
new_output = new_model(**inputs_dict_new)

def test_sharded_checkpoints_with_parallel_loading(self):
import time
assert_tensors_close(
base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match after variant sharded save/load"
)

def test_sharded_checkpoints_with_parallel_loading(self, tmp_path):
from diffusers.utils import constants

torch.manual_seed(0)
@@ -517,47 +603,37 @@ class ModelTesterMixin:
original_parallel_workers = getattr(constants, "HF_PARALLEL_WORKERS", None)

try:
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
assert os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist"
model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB")
assert os.path.exists(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist"

# Check if the right number of shards exists
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
assert actual_num_shards == expected_num_shards, (
f"Expected {expected_num_shards} shards, got {actual_num_shards}"
)
# Check if the right number of shards exists
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME))
actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")])
assert actual_num_shards == expected_num_shards, (
f"Expected {expected_num_shards} shards, got {actual_num_shards}"
)

# Load without parallel loading
constants.HF_ENABLE_PARALLEL_LOADING = False
start_time = time.time()
model_sequential = self.model_class.from_pretrained(tmp_dir).eval()
sequential_load_time = time.time() - start_time
model_sequential = model_sequential.to(torch_device)
# Load without parallel loading
constants.HF_ENABLE_PARALLEL_LOADING = False
model_sequential = self.model_class.from_pretrained(tmp_path).eval()
model_sequential = model_sequential.to(torch_device)

torch.manual_seed(0)
# Load with parallel loading
constants.HF_ENABLE_PARALLEL_LOADING = True
constants.DEFAULT_HF_PARALLEL_LOADING_WORKERS = 2

# Load with parallel loading
constants.HF_ENABLE_PARALLEL_LOADING = True
constants.DEFAULT_HF_PARALLEL_LOADING_WORKERS = 2
torch.manual_seed(0)
model_parallel = self.model_class.from_pretrained(tmp_path).eval()
model_parallel = model_parallel.to(torch_device)

start_time = time.time()
model_parallel = self.model_class.from_pretrained(tmp_dir).eval()
parallel_load_time = time.time() - start_time
model_parallel = model_parallel.to(torch_device)
torch.manual_seed(0)
inputs_dict_parallel = self.get_dummy_inputs()
output_parallel = model_parallel(**inputs_dict_parallel)

torch.manual_seed(0)
inputs_dict_parallel = self.get_dummy_inputs()
output_parallel = model_parallel(**inputs_dict_parallel)
assert_tensors_close(
base_output[0], output_parallel[0], atol=1e-5, rtol=0, msg="Output should match with parallel loading"
)

assert torch.allclose(base_output[0], output_parallel[0], atol=1e-5), (
"Output should match with parallel loading"
)

# Verify parallel loading is faster or at least not significantly slower
assert parallel_load_time < sequential_load_time, (
f"Parallel loading took {parallel_load_time:.4f}s, sequential took {sequential_load_time:.4f}s"
)
finally:
# Restore original values
constants.HF_ENABLE_PARALLEL_LOADING = original_parallel_loading
@@ -565,7 +641,7 @@ class ModelTesterMixin:
constants.HF_PARALLEL_WORKERS = original_parallel_workers
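Editor's note (not part of the diff): the wall-clock timing assertion is removed here, presumably because load-time comparisons are flaky on shared CI, while the module constants are still toggled and restored by hand in a try/finally. An alternative, hypothetical way to express the same toggle uses pytest's monkeypatch fixture, which restores attributes automatically:

from diffusers.utils import constants


def test_parallel_loading_toggle(monkeypatch):
    # Hypothetical variant: monkeypatch undoes these assignments when the test ends.
    monkeypatch.setattr(constants, "HF_ENABLE_PARALLEL_LOADING", True, raising=False)
    monkeypatch.setattr(constants, "DEFAULT_HF_PARALLEL_LOADING_WORKERS", 2, raising=False)
    # ... load the model and run assertions here; no manual try/finally needed.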
@require_torch_multi_accelerator
def test_model_parallelism(self):
def test_model_parallelism(self, tmp_path):
if self.model_class._no_split_modules is None:
pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")

@@ -581,20 +657,19 @@ class ModelTesterMixin:
model_size = compute_module_sizes(model)[""]
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]

with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir)
model.cpu().save_pretrained(tmp_path)

for max_size in max_gpu_sizes:
max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2}
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
# Making sure part of the model will be on GPU 0 and GPU 1
assert set(new_model.hf_device_map.values()) == {0, 1}, "Model should be split across GPUs"
for max_size in max_gpu_sizes:
max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2}
new_model = self.model_class.from_pretrained(tmp_path, device_map="auto", max_memory=max_memory)
# Making sure part of the model will be on GPU 0 and GPU 1
assert set(new_model.hf_device_map.values()) == {0, 1}, "Model should be split across GPUs"

check_device_map_is_respected(new_model, new_model.hf_device_map)
check_device_map_is_respected(new_model, new_model.hf_device_map)

torch.manual_seed(0)
new_output = new_model(**inputs_dict)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)

assert torch.allclose(base_output[0], new_output[0], atol=1e-5), (
"Output should match with model parallelism"
)
assert_tensors_close(
base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match with model parallelism"
)
@@ -13,17 +13,34 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import json
import os
import tempfile
import re

import pytest
import safetensors.torch
import torch
import torch.nn as nn

from diffusers.utils.import_utils import is_peft_available
from diffusers.utils.testing_utils import check_if_dicts_are_equal

from ...testing_utils import is_lora, require_peft_backend, torch_device
from ...testing_utils import (
assert_tensors_close,
backend_empty_cache,
is_lora,
is_torch_compile,
require_peft_backend,
require_peft_version_greater,
require_torch_accelerator,
require_torch_version_greater,
torch_device,
)

if is_peft_available():
from diffusers.loaders.peft import PeftAdapterMixin

def check_if_lora_correctly_set(model) -> bool:
@@ -67,7 +84,7 @@ class LoraTesterMixin:
if not issubclass(self.model_class, PeftAdapterMixin):
pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__}).")

def test_save_load_lora_adapter(self, rank=4, lora_alpha=4, use_dora=False):
def test_save_load_lora_adapter(self, tmp_path, rank=4, lora_alpha=4, use_dora=False):
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
@@ -95,26 +112,25 @@ class LoraTesterMixin:
"Output should differ with LoRA enabled"
)

with tempfile.TemporaryDirectory() as tmpdir:
model.save_lora_adapter(tmpdir)
assert os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")), (
"LoRA weights file not created"
)
model.save_lora_adapter(tmp_path)
assert os.path.isfile(os.path.join(tmp_path, "pytorch_lora_weights.safetensors")), (
"LoRA weights file not created"
)

state_dict_loaded = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
state_dict_loaded = safetensors.torch.load_file(os.path.join(tmp_path, "pytorch_lora_weights.safetensors"))

model.unload_lora()
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
model.unload_lora()
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"

model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0")
model.load_lora_adapter(tmp_path, prefix=None, use_safetensors=True)
state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0")

for k in state_dict_loaded:
loaded_v = state_dict_loaded[k]
retrieved_v = state_dict_retrieved[k].to(loaded_v.device)
assert torch.allclose(loaded_v, retrieved_v), f"Mismatch in LoRA weight {k}"
for k in state_dict_loaded:
loaded_v = state_dict_loaded[k]
retrieved_v = state_dict_retrieved[k].to(loaded_v.device)
assert_tensors_close(loaded_v, retrieved_v, atol=1e-5, rtol=0, msg=f"Mismatch in LoRA weight {k}")

assert check_if_lora_correctly_set(model), "LoRA layers not set correctly after reload"
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly after reload"

torch.manual_seed(0)
outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0]
@@ -122,11 +138,15 @@ class LoraTesterMixin:
assert not torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4), (
"Output should differ with LoRA enabled"
)
assert torch.allclose(outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4), (
"Outputs should match before and after save/load"
assert_tensors_close(
outputs_with_lora,
outputs_with_lora_2,
atol=1e-4,
rtol=1e-4,
msg="Outputs should match before and after save/load",
)
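Editor's note (not part of the diff): these tests drive peft's LoraConfig through model.add_adapter and save_lora_adapter. The construction used elsewhere in this diff looks like the following; the rank, alpha, and target module names are illustrative values.

from peft import LoraConfig

lora_config = LoraConfig(
    r=4,
    lora_alpha=4,
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
    init_lora_weights=False,
    use_dora=False,
)
# model.add_adapter(lora_config)      # attach to a PeftAdapterMixin model
# model.save_lora_adapter(tmp_path)   # writes pytorch_lora_weights.safetensors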
def test_lora_wrong_adapter_name_raises_error(self):
def test_lora_wrong_adapter_name_raises_error(self, tmp_path):
from peft import LoraConfig

init_dict = self.get_init_dict()
@@ -142,14 +162,13 @@ class LoraTesterMixin:
model.add_adapter(denoiser_lora_config)
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"

with tempfile.TemporaryDirectory() as tmpdir:
wrong_name = "foo"
with pytest.raises(ValueError) as exc_info:
model.save_lora_adapter(tmpdir, adapter_name=wrong_name)
wrong_name = "foo"
with pytest.raises(ValueError) as exc_info:
model.save_lora_adapter(tmp_path, adapter_name=wrong_name)

assert f"Adapter name {wrong_name} not found in the model." in str(exc_info.value)
assert f"Adapter name {wrong_name} not found in the model." in str(exc_info.value)

def test_lora_adapter_metadata_is_loaded_correctly(self, rank=4, lora_alpha=4, use_dora=False):
def test_lora_adapter_metadata_is_loaded_correctly(self, tmp_path, rank=4, lora_alpha=4, use_dora=False):
from peft import LoraConfig

init_dict = self.get_init_dict()
@@ -166,19 +185,18 @@ class LoraTesterMixin:
metadata = model.peft_config["default"].to_dict()
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"

with tempfile.TemporaryDirectory() as tmpdir:
model.save_lora_adapter(tmpdir)
model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
assert os.path.isfile(model_file), "LoRA weights file not created"
model.save_lora_adapter(tmp_path)
model_file = os.path.join(tmp_path, "pytorch_lora_weights.safetensors")
assert os.path.isfile(model_file), "LoRA weights file not created"

model.unload_lora()
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
model.unload_lora()
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"

model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
parsed_metadata = model.peft_config["default_0"].to_dict()
check_if_dicts_are_equal(metadata, parsed_metadata)
model.load_lora_adapter(tmp_path, prefix=None, use_safetensors=True)
parsed_metadata = model.peft_config["default_0"].to_dict()
check_if_dicts_are_equal(metadata, parsed_metadata)

def test_lora_adapter_wrong_metadata_raises_error(self):
def test_lora_adapter_wrong_metadata_raises_error(self, tmp_path):
from peft import LoraConfig

from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
@@ -196,25 +214,337 @@ class LoraTesterMixin:
model.add_adapter(denoiser_lora_config)
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"

with tempfile.TemporaryDirectory() as tmpdir:
model.save_lora_adapter(tmpdir)
model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
assert os.path.isfile(model_file), "LoRA weights file not created"
model.save_lora_adapter(tmp_path)
model_file = os.path.join(tmp_path, "pytorch_lora_weights.safetensors")
assert os.path.isfile(model_file), "LoRA weights file not created"

# Perturb the metadata in the state dict
loaded_state_dict = safetensors.torch.load_file(model_file)
metadata = {"format": "pt"}
lora_adapter_metadata = denoiser_lora_config.to_dict()
lora_adapter_metadata.update({"foo": 1, "bar": 2})
for key, value in lora_adapter_metadata.items():
if isinstance(value, set):
lora_adapter_metadata[key] = list(value)
metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata)
# Perturb the metadata in the state dict
loaded_state_dict = safetensors.torch.load_file(model_file)
metadata = {"format": "pt"}
lora_adapter_metadata = denoiser_lora_config.to_dict()
lora_adapter_metadata.update({"foo": 1, "bar": 2})
for key, value in lora_adapter_metadata.items():
if isinstance(value, set):
lora_adapter_metadata[key] = list(value)
metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata)

model.unload_lora()
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
model.unload_lora()
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"

with pytest.raises(TypeError) as exc_info:
model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
assert "`LoraConfig` class could not be instantiated" in str(exc_info.value)
with pytest.raises(TypeError) as exc_info:
model.load_lora_adapter(tmp_path, prefix=None, use_safetensors=True)
assert "`LoraConfig` class could not be instantiated" in str(exc_info.value)
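Editor's note (not part of the diff): the wrong-metadata test rewrites the checkpoint with a perturbed LORA_ADAPTER_METADATA_KEY entry via safetensors.torch.save_file(..., metadata=...). To inspect what was written, the metadata can be read back through the public safetensors API; a small sketch, with the key passed in explicitly:

import json

from safetensors import safe_open


def read_safetensors_metadata(model_file, key):
    # Pass LORA_ADAPTER_METADATA_KEY (imported above) as `key` to recover the adapter config dict.
    with safe_open(model_file, framework="pt") as f:
        metadata = f.metadata() or {}
    return json.loads(metadata[key]) if key in metadata else None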
@is_lora
@is_torch_compile
@require_peft_backend
@require_peft_version_greater("0.14.0")
@require_torch_version_greater("2.7.1")
@require_torch_accelerator
class LoraHotSwappingForModelTesterMixin:
"""
Mixin class for testing LoRA hot swapping functionality on models.

Test that hotswapping does not result in recompilation on the model directly.
We're not extensively testing the hotswapping functionality since it is implemented in PEFT
and is extensively tested there. The goal of this test is specifically to ensure that
hotswapping with diffusers does not require recompilation.

See https://github.com/huggingface/peft/blob/eaab05e18d51fb4cce20a73c9acd82a00c013b83/tests/test_gpu_examples.py#L4252
for the analogous PEFT test.

Expected class attributes to be set by subclasses:
- model_class: The model class to test
- different_shapes_for_compilation: Optional list of (height, width) tuples for dynamic compilation tests

Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass

Pytest marks: lora, torch_compile
Use `pytest -m "not lora"` or `pytest -m "not torch_compile"` to skip these tests
"""

different_shapes_for_compilation = None

def setup_method(self):
if not issubclass(self.model_class, PeftAdapterMixin):
pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__}).")

def teardown_method(self):
# It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model,
# there will be recompilation errors, as torch caches the model when run in the same process.
torch.compiler.reset()
gc.collect()
backend_empty_cache(torch_device)

def _get_lora_config(self, lora_rank, lora_alpha, target_modules):
from peft import LoraConfig

lora_config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
target_modules=target_modules,
init_lora_weights=False,
use_dora=False,
)
return lora_config

def _get_linear_module_name_other_than_attn(self, model):
linear_names = [
name for name, module in model.named_modules() if isinstance(module, nn.Linear) and "to_" not in name
]
return linear_names[0]
def _check_model_hotswap(self, tmp_path, do_compile, rank0, rank1, target_modules0, target_modules1=None):
"""
Check that hotswapping works on a model.

Steps:
- create 2 LoRA adapters and save them
- load the first adapter
- hotswap the second adapter
- check that the outputs are correct
- optionally compile the model
- optionally check if recompilations happen on different shapes

Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would
fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is
fine.
"""
different_shapes = self.different_shapes_for_compilation
# create 2 adapters with different ranks and alphas
torch.manual_seed(0)
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)

alpha0, alpha1 = rank0, rank1
max_rank = max([rank0, rank1])
if target_modules1 is None:
target_modules1 = target_modules0[:]
lora_config0 = self._get_lora_config(rank0, alpha0, target_modules0)
lora_config1 = self._get_lora_config(rank1, alpha1, target_modules1)

model.add_adapter(lora_config0, adapter_name="adapter0")
with torch.inference_mode():
torch.manual_seed(0)
output0_before = model(**inputs_dict)["sample"]

model.add_adapter(lora_config1, adapter_name="adapter1")
model.set_adapter("adapter1")
with torch.inference_mode():
torch.manual_seed(0)
output1_before = model(**inputs_dict)["sample"]

# sanity checks:
tol = 5e-3
assert not torch.allclose(output0_before, output1_before, atol=tol, rtol=tol)
assert not (output0_before == 0).all()
assert not (output1_before == 0).all()

# save the adapter checkpoints
model.save_lora_adapter(os.path.join(tmp_path, "0"), safe_serialization=True, adapter_name="adapter0")
model.save_lora_adapter(os.path.join(tmp_path, "1"), safe_serialization=True, adapter_name="adapter1")
del model

# load the first adapter
torch.manual_seed(0)
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)

if do_compile or (rank0 != rank1):
# no need to prepare if the model is not compiled or if the ranks are identical
model.enable_lora_hotswap(target_rank=max_rank)

file_name0 = os.path.join(os.path.join(tmp_path, "0"), "pytorch_lora_weights.safetensors")
file_name1 = os.path.join(os.path.join(tmp_path, "1"), "pytorch_lora_weights.safetensors")
model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)

if do_compile:
model = torch.compile(model, mode="reduce-overhead", dynamic=different_shapes is not None)

with torch.inference_mode():
# additionally check if dynamic compilation works.
if different_shapes is not None:
for height, width in different_shapes:
new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
_ = model(**new_inputs_dict)
else:
output0_after = model(**inputs_dict)["sample"]
assert_tensors_close(
output0_before, output0_after, atol=tol, rtol=tol, msg="Output mismatch after loading adapter0"
)

# hotswap the 2nd adapter
model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)

# we need to call forward to potentially trigger recompilation
with torch.inference_mode():
if different_shapes is not None:
for height, width in different_shapes:
new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
_ = model(**new_inputs_dict)
else:
output1_after = model(**inputs_dict)["sample"]
assert_tensors_close(
output1_before,
output1_after,
atol=tol,
rtol=tol,
msg="Output mismatch after hotswapping to adapter1",
)

# check error when not passing valid adapter name
name = "does-not-exist"
msg = f"Trying to hotswap LoRA adapter '{name}' but there is no existing adapter by that name"
with pytest.raises(ValueError, match=re.escape(msg)):
model.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
def test_hotswapping_model(self, tmp_path, rank0, rank1):
self._check_model_hotswap(
tmp_path, do_compile=False, rank0=rank0, rank1=rank1, target_modules0=["to_q", "to_k", "to_v", "to_out.0"]
)

@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
def test_hotswapping_compiled_model_linear(self, tmp_path, rank0, rank1):
# It's important to add this context to raise an error on recompilation
target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
self._check_model_hotswap(
tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
)

@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
def test_hotswapping_compiled_model_conv2d(self, tmp_path, rank0, rank1):
if "unet" not in self.model_class.__name__.lower():
pytest.skip("Test only applies to UNet.")

# It's important to add this context to raise an error on recompilation
target_modules = ["conv", "conv1", "conv2"]
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
self._check_model_hotswap(
tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
)

@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
def test_hotswapping_compiled_model_both_linear_and_conv2d(self, tmp_path, rank0, rank1):
if "unet" not in self.model_class.__name__.lower():
pytest.skip("Test only applies to UNet.")

# It's important to add this context to raise an error on recompilation
target_modules = ["to_q", "conv"]
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
self._check_model_hotswap(
tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
)

@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
def test_hotswapping_compiled_model_both_linear_and_other(self, tmp_path, rank0, rank1):
# In `test_hotswapping_compiled_model_both_linear_and_conv2d()`, we check if we can do hotswapping
# with `torch.compile()` for models that have both linear and conv layers. In this test, we check
# if we can target a linear layer from the transformer blocks and another linear layer from non-attention
# block.
target_modules = ["to_q"]
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)

target_modules.append(self._get_linear_module_name_other_than_attn(model))
del model

# It's important to add this context to raise an error on recompilation
with torch._dynamo.config.patch(error_on_recompile=True):
self._check_model_hotswap(
tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
)

def test_enable_lora_hotswap_called_after_adapter_added_raises(self):
# ensure that enable_lora_hotswap is called before loading the first adapter
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)

msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.")
with pytest.raises(RuntimeError, match=msg):
model.enable_lora_hotswap(target_rank=32)

def test_enable_lora_hotswap_called_after_adapter_added_warning(self, caplog):
# ensure that enable_lora_hotswap is called before loading the first adapter
import logging

lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
msg = (
"It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
)
with caplog.at_level(logging.WARNING):
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
assert any(msg in record.message for record in caplog.records)

def test_enable_lora_hotswap_called_after_adapter_added_ignore(self, caplog):
# check possibility to ignore the error/warning
import logging

lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
with caplog.at_level(logging.WARNING):
model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
assert len(caplog.records) == 0

def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
# check that wrong argument value raises an error
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
msg = re.escape("check_compiles should be one of 'error', 'warn', or 'ignore', got 'wrong-argument' instead.")
with pytest.raises(ValueError, match=msg):
model.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")

def test_hotswap_second_adapter_targets_more_layers_raises(self, tmp_path, caplog):
# check the error and log
import logging

# at the moment, PEFT requires the 2nd adapter to target the same or a subset of layers
target_modules0 = ["to_q"]
target_modules1 = ["to_q", "to_k"]
with pytest.raises(RuntimeError):  # peft raises RuntimeError
with caplog.at_level(logging.ERROR):
self._check_model_hotswap(
tmp_path,
do_compile=True,
rank0=8,
rank1=8,
target_modules0=target_modules0,
target_modules1=target_modules1,
)
assert any("Hotswapping adapter0 was unsuccessful" in record.message for record in caplog.records)

@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
@require_torch_version_greater("2.7.1")
def test_hotswapping_compile_on_different_shapes(self, tmp_path, rank0, rank1):
different_shapes_for_compilation = self.different_shapes_for_compilation
if different_shapes_for_compilation is None:
pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
# Specifying `use_duck_shape=False` instructs the compiler if it should use the same symbolic
# variable to represent input sizes that are the same. For more details,
# check out this [comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790).
torch.fx.experimental._config.use_duck_shape = False

target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
with torch._dynamo.config.patch(error_on_recompile=True):
self._check_model_hotswap(
tmp_path,
do_compile=True,
rank0=rank0,
rank1=rank1,
target_modules0=target_modules,
)
@@ -27,6 +27,7 @@ from diffusers.utils.testing_utils import _check_safetensors_serialization
from diffusers.utils.torch_utils import get_torch_cuda_device_capability

from ...testing_utils import (
    assert_tensors_close,
    backend_empty_cache,
    backend_max_memory_allocated,
    backend_reset_peak_memory_stats,
@@ -122,8 +123,8 @@ class CPUOffloadTesterMixin:
        torch.manual_seed(0)
        new_output = new_model(**inputs_dict)

        assert torch.allclose(base_output[0], new_output[0], atol=1e-5), (
            "Output should match with CPU offloading"
        assert_tensors_close(
            base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match with CPU offloading"
        )

    @require_offload_support
@@ -156,7 +157,9 @@ class CPUOffloadTesterMixin:
        torch.manual_seed(0)
        new_output = new_model(**inputs_dict)

        assert torch.allclose(base_output[0], new_output[0], atol=1e-5), "Output should match with disk offloading"
        assert_tensors_close(
            base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match with disk offloading"
        )

    @require_offload_support
    def test_disk_offload_with_safetensors(self):
@@ -183,8 +186,12 @@ class CPUOffloadTesterMixin:
        torch.manual_seed(0)
        new_output = new_model(**inputs_dict)

        assert torch.allclose(base_output[0], new_output[0], atol=1e-5), (
            "Output should match with disk offloading (safetensors)"
        assert_tensors_close(
            base_output[0],
            new_output[0],
            atol=1e-5,
            rtol=0,
            msg="Output should match with disk offloading (safetensors)",
        )

@@ -247,17 +254,33 @@ class GroupOffloadTesterMixin:
        )
        output_with_group_offloading4 = run_forward(model)

        assert torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5), (
            "Output should match with block-level offloading"
        assert_tensors_close(
            output_without_group_offloading,
            output_with_group_offloading1,
            atol=1e-5,
            rtol=0,
            msg="Output should match with block-level offloading",
        )
        assert torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5), (
            "Output should match with non-blocking block-level offloading"
        assert_tensors_close(
            output_without_group_offloading,
            output_with_group_offloading2,
            atol=1e-5,
            rtol=0,
            msg="Output should match with non-blocking block-level offloading",
        )
        assert torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5), (
            "Output should match with leaf-level offloading"
        assert_tensors_close(
            output_without_group_offloading,
            output_with_group_offloading3,
            atol=1e-5,
            rtol=0,
            msg="Output should match with leaf-level offloading",
        )
        assert torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5), (
            "Output should match with leaf-level offloading with stream"
        assert_tensors_close(
            output_without_group_offloading,
            output_with_group_offloading4,
            atol=1e-5,
            rtol=0,
            msg="Output should match with leaf-level offloading with stream",
        )

    @require_group_offload_support
@@ -345,8 +368,12 @@ class GroupOffloadTesterMixin:
            raise ValueError(f"Following files are missing: {', '.join(missing_files)}")

        output_with_group_offloading = _run_forward(model, inputs_dict)
        assert torch.allclose(output_without_group_offloading, output_with_group_offloading, atol=atol), (
            "Output should match with disk-based group offloading"
        assert_tensors_close(
            output_without_group_offloading,
            output_with_group_offloading,
            atol=atol,
            rtol=0,
            msg="Output should match with disk-based group offloading",
        )

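
A detail worth noting across these hunks is the explicit `rtol=0`: `torch.allclose` keeps its default `rtol=1e-5` even when only `atol` is passed, so the old assertions allowed an extra relative margin on large-magnitude outputs, while the new `assert_tensors_close(..., rtol=0)` calls reduce the bound to a pure absolute tolerance. A minimal standalone sketch (not taken from the diff) of the difference:

    import torch

    expected = torch.tensor([1000.0])
    actual = expected + 5e-3  # absolute error of 5e-3

    # Old style: the default rtol=1e-5 still applies, so the bound is
    # 1e-5 + 1e-5 * 1000 ≈ 1.0e-2 and the check passes.
    print(torch.allclose(actual, expected, atol=1e-5))  # True

    # New style: rtol=0 shrinks the bound to atol=1e-5, so the same pair fails.
    print(torch.allclose(actual, expected, atol=1e-5, rtol=0))  # False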
@@ -22,6 +22,7 @@ from huggingface_hub import hf_hub_download, snapshot_download
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name

from ...testing_utils import (
    assert_tensors_close,
    backend_empty_cache,
    is_single_file,
    nightly,
@@ -146,9 +147,8 @@ class SingleFileTesterMixin:
                f"pretrained {param.shape} vs single file {param_single_file.shape}"
            )

            assert torch.allclose(param, param_single_file, rtol=1e-5, atol=1e-5), (
                f"Parameter values differ for {key}: "
                f"max difference {torch.max(torch.abs(param - param_single_file)).item()}"
            assert_tensors_close(
                param, param_single_file, atol=1e-5, rtol=1e-5, msg=f"Parameter values differ for {key}"
            )

    def test_single_file_loading_local_files_only(self):

@@ -23,13 +23,14 @@ from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProc
from diffusers.utils.torch_utils import randn_tensor

from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin
from ..testing_utils import (
    AttentionTesterMixin,
    BaseModelTesterConfig,
    BitsAndBytesTesterMixin,
    ContextParallelTesterMixin,
    GGUFTesterMixin,
    IPAdapterTesterMixin,
    LoraHotSwappingForModelTesterMixin,
    LoraTesterMixin,
    MemoryTesterMixin,
    ModelOptTesterMixin,
@@ -94,10 +95,26 @@ def create_flux_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]:
    return {"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict}


class FluxTransformerTesterConfig:
    model_class = FluxTransformer2DModel
    pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
    pretrained_model_kwargs = {"subfolder": "transformer"}
class FluxTransformerTesterConfig(BaseModelTesterConfig):
    @property
    def model_class(self):
        return FluxTransformer2DModel

    @property
    def pretrained_model_name_or_path(self):
        return "hf-internal-testing/tiny-flux-pipe"

    @property
    def pretrained_model_kwargs(self):
        return {"subfolder": "transformer"}

    @property
    def output_shape(self) -> tuple[int, int]:
        return (16, 4)

    @property
    def input_shape(self) -> tuple[int, int]:
        return (16, 4)

    @property
    def generator(self):
@@ -136,14 +153,6 @@ class FluxTransformerTesterConfig:
            "timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
        }

    @property
    def input_shape(self) -> tuple[int, int]:
        return (16, 4)

    @property
    def output_shape(self) -> tuple[int, int]:
        return (16, 4)


class TestFluxTransformer(FluxTransformerTesterConfig, ModelTesterMixin):
    def test_deprecated_inputs_img_txt_ids_3d(self):

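
The config class is now a `BaseModelTesterConfig` subclass that exposes its settings as read-only properties instead of plain class attributes, and the test class simply mixes the config together with the tester mixins it wants. As a rough sketch of how another model would follow the same pattern (hypothetical names, and assuming `BaseModelTesterConfig` expects the property names seen above):

    # Hypothetical sketch; DummyTransformer2DModel and the tiny checkpoint repo are assumptions.
    class DummyTransformerTesterConfig(BaseModelTesterConfig):
        @property
        def model_class(self):
            return DummyTransformer2DModel

        @property
        def pretrained_model_name_or_path(self):
            return "hf-internal-testing/tiny-dummy-pipe"

        @property
        def input_shape(self) -> tuple[int, int]:
            return (16, 4)

        @property
        def output_shape(self) -> tuple[int, int]:
            return (16, 4)


    class TestDummyTransformer(DummyTransformerTesterConfig, ModelTesterMixin):
        pass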
@@ -131,6 +131,59 @@ def torch_all_close(a, b, *args, **kwargs):
    return True


def assert_tensors_close(
    actual: "torch.Tensor",
    expected: "torch.Tensor",
    atol: float = 1e-5,
    rtol: float = 1e-5,
    msg: str = "",
) -> None:
    """
    Assert that two tensors are close within tolerance.

    Uses the same formula as torch.allclose: |actual - expected| <= atol + rtol * |expected|
    Provides concise, actionable error messages without dumping full tensors.

    Args:
        actual: The actual tensor from the computation.
        expected: The expected tensor to compare against.
        atol: Absolute tolerance.
        rtol: Relative tolerance.
        msg: Optional message prefix for the assertion error.

    Raises:
        AssertionError: If tensors have different shapes or values exceed tolerance.

    Example:
        >>> assert_tensors_close(output, expected_output, atol=1e-5, rtol=1e-5, msg="Forward pass")
    """
    if not is_torch_available():
        raise ValueError("PyTorch needs to be installed to use this function.")

    if actual.shape != expected.shape:
        raise AssertionError(f"{msg} Shape mismatch: actual {actual.shape} vs expected {expected.shape}")

    if not torch.allclose(actual, expected, atol=atol, rtol=rtol):
        abs_diff = (actual - expected).abs()
        max_diff = abs_diff.max().item()

        flat_idx = abs_diff.argmax().item()
        # torch.unravel_index returns a tuple of 0-dim tensors, so convert each one to a Python int.
        max_idx = tuple(idx.item() for idx in torch.unravel_index(torch.tensor(flat_idx), actual.shape))

        threshold = atol + rtol * expected.abs()
        mismatched = (abs_diff > threshold).sum().item()
        total = actual.numel()

        raise AssertionError(
            f"{msg}\n"
            f"Tensors not close! Mismatched elements: {mismatched}/{total} ({100 * mismatched / total:.1f}%)\n"
            f" Max diff: {max_diff:.6e} at index {max_idx}\n"
            f" Actual: {actual.flatten()[flat_idx].item():.6e}\n"
            f" Expected: {expected.flatten()[flat_idx].item():.6e}\n"
            f" atol: {atol:.6e}, rtol: {rtol:.6e}"
        )

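
To make the failure report concrete, here is a small standalone sketch (not part of the diff) of a check that is guaranteed to fail and what the helper reports for it:

    import torch

    actual = torch.zeros(2, 3)
    expected = torch.zeros(2, 3)
    expected[1, 2] = 1e-3  # single element outside the tolerance

    # |0 - 1e-3| = 1e-3 exceeds atol + rtol * |expected| = 1e-5, so this raises an
    # AssertionError reporting "Mismatched elements: 1/6 (16.7%)", a max diff of
    # 1.000000e-03 at index (1, 2), and the actual/expected values at that index.
    assert_tensors_close(actual, expected, atol=1e-5, rtol=0, msg="Sketch of a failing check")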
def numpy_cosine_similarity_distance(a, b):
    similarity = np.dot(a, b) / (norm(a) * norm(b))
    distance = 1.0 - similarity.mean()
    return distance