Commit d08e0bb545 (parent c366b5a817)
Author: DN6
Date: 2025-12-15 14:19:27 +05:30
8 changed files with 793 additions and 286 deletions


@@ -1,8 +1,8 @@
from .attention import AttentionTesterMixin, ContextParallelTesterMixin
from .common import ModelTesterMixin
from .common import BaseModelTesterConfig, ModelTesterMixin
from .compile import TorchCompileTesterMixin
from .ip_adapter import IPAdapterTesterMixin
from .lora import LoraTesterMixin
from .lora import LoraHotSwappingForModelTesterMixin, LoraTesterMixin
from .memory import CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin, MemoryTesterMixin
from .quantization import (
BitsAndBytesTesterMixin,
@@ -17,14 +17,16 @@ from .training import TrainingTesterMixin
__all__ = [
"ContextParallelTesterMixin",
"AttentionTesterMixin",
"BaseModelTesterConfig",
"BitsAndBytesTesterMixin",
"ContextParallelTesterMixin",
"CPUOffloadTesterMixin",
"GGUFTesterMixin",
"GroupOffloadTesterMixin",
"IPAdapterTesterMixin",
"LayerwiseCastingTesterMixin",
"LoraHotSwappingForModelTesterMixin",
"LoraTesterMixin",
"MemoryTesterMixin",
"ModelOptTesterMixin",


@@ -25,7 +25,13 @@ from diffusers.models.attention_processor import (
AttnProcessor,
)
from ...testing_utils import is_attention, is_context_parallel, require_torch_multi_accelerator, torch_device
from ...testing_utils import (
assert_tensors_close,
is_attention,
is_context_parallel,
require_torch_multi_accelerator,
torch_device,
)
@is_attention
@@ -89,8 +95,12 @@ class AttentionTesterMixin:
output_after_fusion = output_after_fusion.to_tuple()[0]
# Verify outputs match
assert torch.allclose(output_before_fusion, output_after_fusion, atol=self.base_precision), (
"Output should not change after fusing projections"
assert_tensors_close(
output_before_fusion,
output_after_fusion,
atol=self.base_precision,
rtol=0,
msg="Output should not change after fusing projections",
)
# Unfuse projections
@@ -110,8 +120,12 @@ class AttentionTesterMixin:
output_after_unfusion = output_after_unfusion.to_tuple()[0]
# Verify outputs still match
assert torch.allclose(output_before_fusion, output_after_unfusion, atol=self.base_precision), (
"Output should match original after unfusing projections"
assert_tensors_close(
output_before_fusion,
output_after_unfusion,
atol=self.base_precision,
rtol=0,
msg="Output should match original after unfusing projections",
)
def test_get_set_processor(self):
@@ -238,9 +252,6 @@ class ContextParallelTesterMixin:
if not torch.distributed.is_available():
pytest.skip("torch.distributed is not available.")
if not torch.cuda.is_available() or torch.cuda.device_count() < 2:
pytest.skip("Context parallel requires at least 2 CUDA devices.")
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")


@@ -15,8 +15,8 @@
import json
import os
import tempfile
from collections import defaultdict
from typing import Any, Dict, Optional, Type
import pytest
import torch
@@ -26,7 +26,7 @@ from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, d
from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, _add_variant, logging
from diffusers.utils.testing_utils import require_accelerator, require_torch_multi_accelerator
from ...testing_utils import CaptureLogger, torch_device
from ...testing_utils import assert_tensors_close, torch_device
def named_persistent_module_tensors(
@@ -130,40 +130,144 @@ def check_device_map_is_respected(model, device_map):
)
class BaseModelTesterConfig:
"""
Base class defining the configuration interface for model testing.
This class defines the contract that all model test classes must implement.
It provides a consistent interface for accessing model configuration, initialization
parameters, and test inputs across all testing mixins.
Required properties (must be implemented by subclasses):
- model_class: The model class to test
Optional properties (can be overridden, have sensible defaults):
- pretrained_model_name_or_path: Hub repository ID for pretrained model (default: None)
- pretrained_model_kwargs: Additional kwargs for from_pretrained (default: {})
- output_shape: Expected output shape for output validation tests (default: None)
- base_precision: Default tolerance for floating point comparisons (default: 1e-3)
- model_split_percents: Percentages for model parallelism tests (default: [0.5, 0.7])
Required methods (must be implemented by subclasses):
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Example usage:
class MyModelTestConfig(BaseModelTesterConfig):
@property
def model_class(self):
return MyModel
@property
def pretrained_model_name_or_path(self):
return "org/my-model"
@property
def output_shape(self):
return (1, 3, 32, 32)
def get_init_dict(self):
return {"in_channels": 3, "out_channels": 3}
def get_dummy_inputs(self):
return {"sample": torch.randn(1, 3, 32, 32, device=torch_device)}
class TestMyModel(MyModelTestConfig, ModelTesterMixin, QuantizationTesterMixin):
pass
"""
# ==================== Required Properties ====================
@property
def model_class(self) -> Type[nn.Module]:
"""The model class to test. Must be implemented by subclasses."""
raise NotImplementedError("Subclasses must implement the `model_class` property.")
# ==================== Optional Properties ====================
@property
def pretrained_model_name_or_path(self) -> Optional[str]:
"""Hub repository ID for the pretrained model (used for quantization and hub tests)."""
return None
@property
def pretrained_model_kwargs(self) -> Dict[str, Any]:
"""Additional kwargs to pass to from_pretrained (e.g., subfolder, variant)."""
return {}
@property
def output_shape(self) -> Optional[tuple]:
"""Expected output shape for output validation tests."""
return None
@property
def model_split_percents(self) -> list:
"""Percentages for model parallelism tests."""
return [0.5, 0.7]
# ==================== Required Methods ====================
def get_init_dict(self) -> Dict[str, Any]:
"""
Returns dict of arguments to initialize the model.
Returns:
Dict[str, Any]: Initialization arguments for the model constructor.
Example:
return {
"in_channels": 3,
"out_channels": 3,
"sample_size": 32,
}
"""
raise NotImplementedError("Subclasses must implement `get_init_dict()`.")
def get_dummy_inputs(self) -> Dict[str, Any]:
"""
Returns dict of inputs to pass to the model forward pass.
Returns:
Dict[str, Any]: Input tensors/values for model.forward().
Example:
return {
"sample": torch.randn(1, 3, 32, 32, device=torch_device),
"timestep": torch.tensor([1], device=torch_device),
}
"""
raise NotImplementedError("Subclasses must implement `get_dummy_inputs()`.")
class ModelTesterMixin:
"""
Base mixin class for model testing with common test methods.
Expected class attributes to be set by subclasses:
This mixin expects the test class to also inherit from BaseModelTesterConfig
(or implement its interface) which provides:
- model_class: The model class to test
- main_input_name: Name of the main input tensor (e.g., "sample", "hidden_states")
- base_precision: Default tolerance for floating point comparisons (default: 1e-3)
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Example:
class MyModelTestConfig(BaseModelTesterConfig):
model_class = MyModel
def get_init_dict(self): ...
def get_dummy_inputs(self): ...
class TestMyModel(MyModelTestConfig, ModelTesterMixin):
pass
"""
model_class = None
base_precision = 1e-3
model_split_percents = [0.5, 0.7]
def get_init_dict(self):
raise NotImplementedError("get_init_dict must be implemented by subclasses. ")
def get_dummy_inputs(self):
raise NotImplementedError("get_dummy_inputs must be implemented by subclasses. It should return inputs_dict.")
def test_from_save_pretrained(self, expected_max_diff=5e-5):
def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=0):
torch.manual_seed(0)
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
new_model = self.model_class.from_pretrained(tmpdirname)
new_model.to(torch_device)
model.save_pretrained(tmp_path)
new_model = self.model_class.from_pretrained(tmp_path)
new_model.to(torch_device)
# check if all parameters shape are the same
for param_name in model.state_dict().keys():
@@ -184,28 +288,24 @@ class ModelTesterMixin:
if isinstance(new_image, dict):
new_image = new_image.to_tuple()[0]
max_diff = (image - new_image).abs().max().item()
assert max_diff <= expected_max_diff, (
f"Models give different forward passes. Max diff: {max_diff}, expected: {expected_max_diff}"
)
assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
def test_from_save_pretrained_variant(self, expected_max_diff=5e-5):
def test_from_save_pretrained_variant(self, tmp_path, atol=5e-5, rtol=0):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, variant="fp16")
new_model = self.model_class.from_pretrained(tmpdirname, variant="fp16")
model.save_pretrained(tmp_path, variant="fp16")
new_model = self.model_class.from_pretrained(tmp_path, variant="fp16")
# non-variant cannot be loaded
with pytest.raises(OSError) as exc_info:
self.model_class.from_pretrained(tmpdirname)
# non-variant cannot be loaded
with pytest.raises(OSError) as exc_info:
self.model_class.from_pretrained(tmp_path)
# make sure that error message states what keys are missing
assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(exc_info.value)
# make sure that error message states what keys are missing
assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(exc_info.value)
new_model.to(torch_device)
new_model.to(torch_device)
with torch.no_grad():
image = model(**self.get_dummy_inputs())
@@ -217,35 +317,27 @@ class ModelTesterMixin:
if isinstance(new_image, dict):
new_image = new_image.to_tuple()[0]
max_diff = (image - new_image).abs().max().item()
assert max_diff <= expected_max_diff, (
f"Models give different forward passes. Max diff: {max_diff}, expected: {expected_max_diff}"
)
assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
def test_from_save_pretrained_dtype(self):
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"])
def test_from_save_pretrained_dtype(self, tmp_path, dtype):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
if torch_device == "mps" and dtype == torch.bfloat16:
continue
with tempfile.TemporaryDirectory() as tmpdirname:
model.to(dtype)
model.save_pretrained(tmpdirname)
new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True, torch_dtype=dtype)
assert new_model.dtype == dtype
if (
hasattr(self.model_class, "_keep_in_fp32_modules")
and self.model_class._keep_in_fp32_modules is None
):
# When loading without accelerate dtype == torch.float32 if _keep_in_fp32_modules is not None
new_model = self.model_class.from_pretrained(
tmpdirname, low_cpu_mem_usage=False, torch_dtype=dtype
)
assert new_model.dtype == dtype
if torch_device == "mps" and dtype == torch.bfloat16:
pytest.skip(reason=f"{dtype} is not supported on {torch_device}")
def test_determinism(self, expected_max_diff=1e-5):
model.to(dtype)
model.save_pretrained(tmp_path)
new_model = self.model_class.from_pretrained(tmp_path, low_cpu_mem_usage=True, torch_dtype=dtype)
assert new_model.dtype == dtype
if hasattr(self.model_class, "_keep_in_fp32_modules") and self.model_class._keep_in_fp32_modules is None:
# When loading without accelerate dtype == torch.float32 if _keep_in_fp32_modules is not None
new_model = self.model_class.from_pretrained(tmp_path, low_cpu_mem_usage=False, torch_dtype=dtype)
assert new_model.dtype == dtype
def test_determinism(self, atol=1e-5, rtol=0):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
@@ -259,18 +351,15 @@ class ModelTesterMixin:
if isinstance(second, dict):
second = second.to_tuple()[0]
# Remove NaN values and compute max difference
# Filter out NaN values before comparison
first_flat = first.flatten()
second_flat = second.flatten()
# Filter out NaN values
mask = ~(torch.isnan(first_flat) | torch.isnan(second_flat))
first_filtered = first_flat[mask]
second_filtered = second_flat[mask]
max_diff = torch.abs(first_filtered - second_filtered).max().item()
assert max_diff <= expected_max_diff, (
f"Model outputs are not deterministic. Max diff: {max_diff}, expected: {expected_max_diff}"
assert_tensors_close(
first_filtered, second_filtered, atol=atol, rtol=rtol, msg="Model outputs are not deterministic"
)
def test_output(self, expected_output_shape=None):
@@ -310,13 +399,12 @@ class ModelTesterMixin:
elif tuple_object is None:
return
else:
assert torch.allclose(
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
), (
"Tuple and dict output are not equal. Difference:"
f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
assert_tensors_close(
set_nan_tensor_to_zero(tuple_object),
set_nan_tensor_to_zero(dict_object),
atol=1e-5,
rtol=0,
msg="Tuple and dict output are not equal",
)
model = self.model_class(**self.get_init_dict())
@@ -329,7 +417,7 @@ class ModelTesterMixin:
recursive_check(outputs_tuple, outputs_dict)
def test_getattr_is_correct(self):
def test_getattr_is_correct(self, caplog):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
@@ -337,28 +425,26 @@ class ModelTesterMixin:
model.dummy_attribute = 5
model.register_to_config(test_attribute=5)
logger = logging.get_logger("diffusers.models.modeling_utils")
# 30 for warning
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
logger_name = "diffusers.models.modeling_utils"
with caplog.at_level(logging.WARNING, logger=logger_name):
caplog.clear()
assert hasattr(model, "dummy_attribute")
assert getattr(model, "dummy_attribute") == 5
assert model.dummy_attribute == 5
# no warning should be thrown
assert cap_logger.out == ""
assert caplog.text == ""
logger = logging.get_logger("diffusers.models.modeling_utils")
# 30 for warning
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
with caplog.at_level(logging.WARNING, logger=logger_name):
caplog.clear()
assert hasattr(model, "save_pretrained")
fn = model.save_pretrained
fn_1 = getattr(model, "save_pretrained")
assert fn == fn_1
# no warning should be thrown
assert cap_logger.out == ""
assert caplog.text == ""
# warning should be thrown for config attributes accessed directly
with pytest.warns(FutureWarning):
@@ -399,32 +485,34 @@ class ModelTesterMixin:
torch_device not in ["cuda", "xpu"],
reason="float16 and bfloat16 can only be use for inference with an accelerator",
)
def test_from_save_pretrained_float16_bfloat16(self):
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
fp32_modules = model._keep_in_fp32_modules
with tempfile.TemporaryDirectory() as tmp_dir:
for torch_dtype in [torch.bfloat16, torch.float16]:
model.to(torch_dtype).save_pretrained(tmp_dir)
model_loaded = self.model_class.from_pretrained(tmp_dir, torch_dtype=torch_dtype).to(torch_device)
model.to(dtype).save_pretrained(tmp_path)
model_loaded = self.model_class.from_pretrained(tmp_path, torch_dtype=dtype).to(torch_device)
for name, param in model_loaded.named_parameters():
if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules):
assert param.data.dtype == torch.float32
else:
assert param.data.dtype == torch_dtype
for name, param in model_loaded.named_parameters():
if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules):
assert param.data.dtype == torch.float32
else:
assert param.data.dtype == dtype
with torch.no_grad():
output = model(**self.get_dummy_inputs())
output_loaded = model_loaded(**self.get_dummy_inputs())
with torch.no_grad():
output = model(**self.get_dummy_inputs())
if isinstance(output, dict):
output = output.to_tuple()[0]
assert torch.allclose(output, output_loaded, atol=1e-4), (
f"Loaded model output differs for {torch_dtype}"
)
output_loaded = model_loaded(**self.get_dummy_inputs())
if isinstance(output_loaded, dict):
output_loaded = output_loaded.to_tuple()[0]
assert_tensors_close(output, output_loaded, atol=1e-4, rtol=0, msg=f"Loaded model output differs for {dtype}")
@require_accelerator
def test_sharded_checkpoints(self):
def test_sharded_checkpoints(self, tmp_path):
torch.manual_seed(0)
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
@@ -435,30 +523,30 @@ class ModelTesterMixin:
model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
assert os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist"
# Check if the right number of shards exists
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
assert actual_num_shards == expected_num_shards, (
f"Expected {expected_num_shards} shards, got {actual_num_shards}"
)
model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB")
assert os.path.exists(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist"
new_model = self.model_class.from_pretrained(tmp_dir).eval()
new_model = new_model.to(torch_device)
# Check if the right number of shards exists
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME))
actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")])
assert actual_num_shards == expected_num_shards, (
f"Expected {expected_num_shards} shards, got {actual_num_shards}"
)
torch.manual_seed(0)
inputs_dict_new = self.get_dummy_inputs()
new_output = new_model(**inputs_dict_new)
new_model = self.model_class.from_pretrained(tmp_path).eval()
new_model = new_model.to(torch_device)
assert torch.allclose(base_output[0], new_output[0], atol=1e-5), (
"Output should match after sharded save/load"
)
torch.manual_seed(0)
inputs_dict_new = self.get_dummy_inputs()
new_output = new_model(**inputs_dict_new)
assert_tensors_close(
base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match after sharded save/load"
)
@require_accelerator
def test_sharded_checkpoints_with_variant(self):
def test_sharded_checkpoints_with_variant(self, tmp_path):
torch.manual_seed(0)
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
@@ -470,35 +558,33 @@ class ModelTesterMixin:
model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
variant = "fp16"
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB", variant=variant)
index_filename = _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
assert os.path.exists(os.path.join(tmp_dir, index_filename)), (
f"Variant index file {index_filename} should exist"
)
model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB", variant=variant)
# Check if the right number of shards exists
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_dir, index_filename))
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
assert actual_num_shards == expected_num_shards, (
f"Expected {expected_num_shards} shards, got {actual_num_shards}"
)
index_filename = _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
assert os.path.exists(os.path.join(tmp_path, index_filename)), (
f"Variant index file {index_filename} should exist"
)
new_model = self.model_class.from_pretrained(tmp_dir, variant=variant).eval()
new_model = new_model.to(torch_device)
# Check if the right number of shards exists
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_path, index_filename))
actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")])
assert actual_num_shards == expected_num_shards, (
f"Expected {expected_num_shards} shards, got {actual_num_shards}"
)
torch.manual_seed(0)
inputs_dict_new = self.get_dummy_inputs()
new_output = new_model(**inputs_dict_new)
new_model = self.model_class.from_pretrained(tmp_path, variant=variant).eval()
new_model = new_model.to(torch_device)
assert torch.allclose(base_output[0], new_output[0], atol=1e-5), (
"Output should match after variant sharded save/load"
)
torch.manual_seed(0)
inputs_dict_new = self.get_dummy_inputs()
new_output = new_model(**inputs_dict_new)
def test_sharded_checkpoints_with_parallel_loading(self):
import time
assert_tensors_close(
base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match after variant sharded save/load"
)
def test_sharded_checkpoints_with_parallel_loading(self, tmp_path):
from diffusers.utils import constants
torch.manual_seed(0)
@@ -517,47 +603,37 @@ class ModelTesterMixin:
original_parallel_workers = getattr(constants, "HF_PARALLEL_WORKERS", None)
try:
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir, max_shard_size=f"{max_shard_size}KB")
assert os.path.exists(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist"
model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB")
assert os.path.exists(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist"
# Check if the right number of shards exists
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME))
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
assert actual_num_shards == expected_num_shards, (
f"Expected {expected_num_shards} shards, got {actual_num_shards}"
)
# Check if the right number of shards exists
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME))
actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")])
assert actual_num_shards == expected_num_shards, (
f"Expected {expected_num_shards} shards, got {actual_num_shards}"
)
# Load without parallel loading
constants.HF_ENABLE_PARALLEL_LOADING = False
start_time = time.time()
model_sequential = self.model_class.from_pretrained(tmp_dir).eval()
sequential_load_time = time.time() - start_time
model_sequential = model_sequential.to(torch_device)
# Load without parallel loading
constants.HF_ENABLE_PARALLEL_LOADING = False
model_sequential = self.model_class.from_pretrained(tmp_path).eval()
model_sequential = model_sequential.to(torch_device)
torch.manual_seed(0)
# Load with parallel loading
constants.HF_ENABLE_PARALLEL_LOADING = True
constants.DEFAULT_HF_PARALLEL_LOADING_WORKERS = 2
# Load with parallel loading
constants.HF_ENABLE_PARALLEL_LOADING = True
constants.DEFAULT_HF_PARALLEL_LOADING_WORKERS = 2
torch.manual_seed(0)
model_parallel = self.model_class.from_pretrained(tmp_path).eval()
model_parallel = model_parallel.to(torch_device)
start_time = time.time()
model_parallel = self.model_class.from_pretrained(tmp_dir).eval()
parallel_load_time = time.time() - start_time
model_parallel = model_parallel.to(torch_device)
torch.manual_seed(0)
inputs_dict_parallel = self.get_dummy_inputs()
output_parallel = model_parallel(**inputs_dict_parallel)
torch.manual_seed(0)
inputs_dict_parallel = self.get_dummy_inputs()
output_parallel = model_parallel(**inputs_dict_parallel)
assert_tensors_close(
base_output[0], output_parallel[0], atol=1e-5, rtol=0, msg="Output should match with parallel loading"
)
assert torch.allclose(base_output[0], output_parallel[0], atol=1e-5), (
"Output should match with parallel loading"
)
# Verify parallel loading is faster or at least not significantly slower
assert parallel_load_time < sequential_load_time, (
f"Parallel loading took {parallel_load_time:.4f}s, sequential took {sequential_load_time:.4f}s"
)
finally:
# Restore original values
constants.HF_ENABLE_PARALLEL_LOADING = original_parallel_loading
@@ -565,7 +641,7 @@ class ModelTesterMixin:
constants.HF_PARALLEL_WORKERS = original_parallel_workers
@require_torch_multi_accelerator
def test_model_parallelism(self):
def test_model_parallelism(self, tmp_path):
if self.model_class._no_split_modules is None:
pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
@@ -581,20 +657,19 @@ class ModelTesterMixin:
model_size = compute_module_sizes(model)[""]
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir)
model.cpu().save_pretrained(tmp_path)
for max_size in max_gpu_sizes:
max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2}
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
# Making sure part of the model will be on GPU 0 and GPU 1
assert set(new_model.hf_device_map.values()) == {0, 1}, "Model should be split across GPUs"
for max_size in max_gpu_sizes:
max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2}
new_model = self.model_class.from_pretrained(tmp_path, device_map="auto", max_memory=max_memory)
# Making sure part of the model will be on GPU 0 and GPU 1
assert set(new_model.hf_device_map.values()) == {0, 1}, "Model should be split across GPUs"
check_device_map_is_respected(new_model, new_model.hf_device_map)
check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
assert torch.allclose(base_output[0], new_output[0], atol=1e-5), (
"Output should match with model parallelism"
)
assert_tensors_close(
base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match with model parallelism"
)


@@ -13,17 +13,34 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import json
import os
import tempfile
import re
import pytest
import safetensors.torch
import torch
import torch.nn as nn
from diffusers.utils.import_utils import is_peft_available
from diffusers.utils.testing_utils import check_if_dicts_are_equal
from ...testing_utils import is_lora, require_peft_backend, torch_device
from ...testing_utils import (
assert_tensors_close,
backend_empty_cache,
is_lora,
is_torch_compile,
require_peft_backend,
require_peft_version_greater,
require_torch_accelerator,
require_torch_version_greater,
torch_device,
)
if is_peft_available():
from diffusers.loaders.peft import PeftAdapterMixin
def check_if_lora_correctly_set(model) -> bool:
@@ -67,7 +84,7 @@ class LoraTesterMixin:
if not issubclass(self.model_class, PeftAdapterMixin):
pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__}).")
def test_save_load_lora_adapter(self, rank=4, lora_alpha=4, use_dora=False):
def test_save_load_lora_adapter(self, tmp_path, rank=4, lora_alpha=4, use_dora=False):
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
@@ -95,26 +112,25 @@ class LoraTesterMixin:
"Output should differ with LoRA enabled"
)
with tempfile.TemporaryDirectory() as tmpdir:
model.save_lora_adapter(tmpdir)
assert os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")), (
"LoRA weights file not created"
)
model.save_lora_adapter(tmp_path)
assert os.path.isfile(os.path.join(tmp_path, "pytorch_lora_weights.safetensors")), (
"LoRA weights file not created"
)
state_dict_loaded = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
state_dict_loaded = safetensors.torch.load_file(os.path.join(tmp_path, "pytorch_lora_weights.safetensors"))
model.unload_lora()
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
model.unload_lora()
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0")
model.load_lora_adapter(tmp_path, prefix=None, use_safetensors=True)
state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0")
for k in state_dict_loaded:
loaded_v = state_dict_loaded[k]
retrieved_v = state_dict_retrieved[k].to(loaded_v.device)
assert torch.allclose(loaded_v, retrieved_v), f"Mismatch in LoRA weight {k}"
for k in state_dict_loaded:
loaded_v = state_dict_loaded[k]
retrieved_v = state_dict_retrieved[k].to(loaded_v.device)
assert_tensors_close(loaded_v, retrieved_v, atol=1e-5, rtol=0, msg=f"Mismatch in LoRA weight {k}")
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly after reload"
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly after reload"
torch.manual_seed(0)
outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0]
@@ -122,11 +138,15 @@ class LoraTesterMixin:
assert not torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4), (
"Output should differ with LoRA enabled"
)
assert torch.allclose(outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4), (
"Outputs should match before and after save/load"
assert_tensors_close(
outputs_with_lora,
outputs_with_lora_2,
atol=1e-4,
rtol=1e-4,
msg="Outputs should match before and after save/load",
)
def test_lora_wrong_adapter_name_raises_error(self):
def test_lora_wrong_adapter_name_raises_error(self, tmp_path):
from peft import LoraConfig
init_dict = self.get_init_dict()
@@ -142,14 +162,13 @@ class LoraTesterMixin:
model.add_adapter(denoiser_lora_config)
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
with tempfile.TemporaryDirectory() as tmpdir:
wrong_name = "foo"
with pytest.raises(ValueError) as exc_info:
model.save_lora_adapter(tmpdir, adapter_name=wrong_name)
wrong_name = "foo"
with pytest.raises(ValueError) as exc_info:
model.save_lora_adapter(tmp_path, adapter_name=wrong_name)
assert f"Adapter name {wrong_name} not found in the model." in str(exc_info.value)
assert f"Adapter name {wrong_name} not found in the model." in str(exc_info.value)
def test_lora_adapter_metadata_is_loaded_correctly(self, rank=4, lora_alpha=4, use_dora=False):
def test_lora_adapter_metadata_is_loaded_correctly(self, tmp_path, rank=4, lora_alpha=4, use_dora=False):
from peft import LoraConfig
init_dict = self.get_init_dict()
@@ -166,19 +185,18 @@ class LoraTesterMixin:
metadata = model.peft_config["default"].to_dict()
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
with tempfile.TemporaryDirectory() as tmpdir:
model.save_lora_adapter(tmpdir)
model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
assert os.path.isfile(model_file), "LoRA weights file not created"
model.save_lora_adapter(tmp_path)
model_file = os.path.join(tmp_path, "pytorch_lora_weights.safetensors")
assert os.path.isfile(model_file), "LoRA weights file not created"
model.unload_lora()
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
model.unload_lora()
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
parsed_metadata = model.peft_config["default_0"].to_dict()
check_if_dicts_are_equal(metadata, parsed_metadata)
model.load_lora_adapter(tmp_path, prefix=None, use_safetensors=True)
parsed_metadata = model.peft_config["default_0"].to_dict()
check_if_dicts_are_equal(metadata, parsed_metadata)
def test_lora_adapter_wrong_metadata_raises_error(self):
def test_lora_adapter_wrong_metadata_raises_error(self, tmp_path):
from peft import LoraConfig
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
@@ -196,25 +214,337 @@ class LoraTesterMixin:
model.add_adapter(denoiser_lora_config)
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
with tempfile.TemporaryDirectory() as tmpdir:
model.save_lora_adapter(tmpdir)
model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
assert os.path.isfile(model_file), "LoRA weights file not created"
model.save_lora_adapter(tmp_path)
model_file = os.path.join(tmp_path, "pytorch_lora_weights.safetensors")
assert os.path.isfile(model_file), "LoRA weights file not created"
# Perturb the metadata in the state dict
loaded_state_dict = safetensors.torch.load_file(model_file)
metadata = {"format": "pt"}
lora_adapter_metadata = denoiser_lora_config.to_dict()
lora_adapter_metadata.update({"foo": 1, "bar": 2})
for key, value in lora_adapter_metadata.items():
if isinstance(value, set):
lora_adapter_metadata[key] = list(value)
metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata)
# Perturb the metadata in the state dict
loaded_state_dict = safetensors.torch.load_file(model_file)
metadata = {"format": "pt"}
lora_adapter_metadata = denoiser_lora_config.to_dict()
lora_adapter_metadata.update({"foo": 1, "bar": 2})
for key, value in lora_adapter_metadata.items():
if isinstance(value, set):
lora_adapter_metadata[key] = list(value)
metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata)
model.unload_lora()
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
model.unload_lora()
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
with pytest.raises(TypeError) as exc_info:
model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
assert "`LoraConfig` class could not be instantiated" in str(exc_info.value)
with pytest.raises(TypeError) as exc_info:
model.load_lora_adapter(tmp_path, prefix=None, use_safetensors=True)
assert "`LoraConfig` class could not be instantiated" in str(exc_info.value)
@is_lora
@is_torch_compile
@require_peft_backend
@require_peft_version_greater("0.14.0")
@require_torch_version_greater("2.7.1")
@require_torch_accelerator
class LoraHotSwappingForModelTesterMixin:
"""
Mixin class for testing LoRA hot swapping functionality on models.
Test that hotswapping does not result in recompilation on the model directly.
We're not extensively testing the hotswapping functionality since it is implemented in PEFT
and is extensively tested there. The goal of this test is specifically to ensure that
hotswapping with diffusers does not require recompilation.
See https://github.com/huggingface/peft/blob/eaab05e18d51fb4cce20a73c9acd82a00c013b83/tests/test_gpu_examples.py#L4252
for the analogous PEFT test.
Expected class attributes to be set by subclasses:
- model_class: The model class to test
- different_shapes_for_compilation: Optional list of (height, width) tuples for dynamic compilation tests
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest marks: lora, torch_compile
Use `pytest -m "not lora"` or `pytest -m "not torch_compile"` to skip these tests
"""
different_shapes_for_compilation = None
def setup_method(self):
if not issubclass(self.model_class, PeftAdapterMixin):
pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__}).")
def teardown_method(self):
# It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model,
# there will be recompilation errors, as torch caches the model when run in the same process.
torch.compiler.reset()
gc.collect()
backend_empty_cache(torch_device)
def _get_lora_config(self, lora_rank, lora_alpha, target_modules):
from peft import LoraConfig
lora_config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
target_modules=target_modules,
init_lora_weights=False,
use_dora=False,
)
return lora_config
def _get_linear_module_name_other_than_attn(self, model):
linear_names = [
name for name, module in model.named_modules() if isinstance(module, nn.Linear) and "to_" not in name
]
return linear_names[0]
def _check_model_hotswap(self, tmp_path, do_compile, rank0, rank1, target_modules0, target_modules1=None):
"""
Check that hotswapping works on a model.
Steps:
- create 2 LoRA adapters and save them
- load the first adapter
- hotswap the second adapter
- check that the outputs are correct
- optionally compile the model
- optionally check if recompilations happen on different shapes
Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would
fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is
fine.
"""
different_shapes = self.different_shapes_for_compilation
# create 2 adapters with different ranks and alphas
torch.manual_seed(0)
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
alpha0, alpha1 = rank0, rank1
max_rank = max([rank0, rank1])
if target_modules1 is None:
target_modules1 = target_modules0[:]
lora_config0 = self._get_lora_config(rank0, alpha0, target_modules0)
lora_config1 = self._get_lora_config(rank1, alpha1, target_modules1)
model.add_adapter(lora_config0, adapter_name="adapter0")
with torch.inference_mode():
torch.manual_seed(0)
output0_before = model(**inputs_dict)["sample"]
model.add_adapter(lora_config1, adapter_name="adapter1")
model.set_adapter("adapter1")
with torch.inference_mode():
torch.manual_seed(0)
output1_before = model(**inputs_dict)["sample"]
# sanity checks:
tol = 5e-3
assert not torch.allclose(output0_before, output1_before, atol=tol, rtol=tol)
assert not (output0_before == 0).all()
assert not (output1_before == 0).all()
# save the adapter checkpoints
model.save_lora_adapter(os.path.join(tmp_path, "0"), safe_serialization=True, adapter_name="adapter0")
model.save_lora_adapter(os.path.join(tmp_path, "1"), safe_serialization=True, adapter_name="adapter1")
del model
# load the first adapter
torch.manual_seed(0)
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
if do_compile or (rank0 != rank1):
# no need to prepare if the model is not compiled or if the ranks are identical
model.enable_lora_hotswap(target_rank=max_rank)
file_name0 = os.path.join(os.path.join(tmp_path, "0"), "pytorch_lora_weights.safetensors")
file_name1 = os.path.join(os.path.join(tmp_path, "1"), "pytorch_lora_weights.safetensors")
model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)
if do_compile:
model = torch.compile(model, mode="reduce-overhead", dynamic=different_shapes is not None)
with torch.inference_mode():
# additionally check if dynamic compilation works.
if different_shapes is not None:
for height, width in different_shapes:
new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
_ = model(**new_inputs_dict)
else:
output0_after = model(**inputs_dict)["sample"]
assert_tensors_close(
output0_before, output0_after, atol=tol, rtol=tol, msg="Output mismatch after loading adapter0"
)
# hotswap the 2nd adapter
model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)
# we need to call forward to potentially trigger recompilation
with torch.inference_mode():
if different_shapes is not None:
for height, width in different_shapes:
new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
_ = model(**new_inputs_dict)
else:
output1_after = model(**inputs_dict)["sample"]
assert_tensors_close(
output1_before,
output1_after,
atol=tol,
rtol=tol,
msg="Output mismatch after hotswapping to adapter1",
)
# check error when not passing valid adapter name
name = "does-not-exist"
msg = f"Trying to hotswap LoRA adapter '{name}' but there is no existing adapter by that name"
with pytest.raises(ValueError, match=re.escape(msg)):
model.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
def test_hotswapping_model(self, tmp_path, rank0, rank1):
self._check_model_hotswap(
tmp_path, do_compile=False, rank0=rank0, rank1=rank1, target_modules0=["to_q", "to_k", "to_v", "to_out.0"]
)
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
def test_hotswapping_compiled_model_linear(self, tmp_path, rank0, rank1):
# It's important to add this context to raise an error on recompilation
target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
self._check_model_hotswap(
tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
)
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
def test_hotswapping_compiled_model_conv2d(self, tmp_path, rank0, rank1):
if "unet" not in self.model_class.__name__.lower():
pytest.skip("Test only applies to UNet.")
# It's important to add this context to raise an error on recompilation
target_modules = ["conv", "conv1", "conv2"]
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
self._check_model_hotswap(
tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
)
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
def test_hotswapping_compiled_model_both_linear_and_conv2d(self, tmp_path, rank0, rank1):
if "unet" not in self.model_class.__name__.lower():
pytest.skip("Test only applies to UNet.")
# It's important to add this context to raise an error on recompilation
target_modules = ["to_q", "conv"]
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
self._check_model_hotswap(
tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
)
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
def test_hotswapping_compiled_model_both_linear_and_other(self, tmp_path, rank0, rank1):
# In `test_hotswapping_compiled_model_both_linear_and_conv2d()`, we check if we can do hotswapping
# with `torch.compile()` for models that have both linear and conv layers. In this test, we check
# if we can target a linear layer from the transformer blocks and another linear layer from non-attention
# block.
target_modules = ["to_q"]
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
target_modules.append(self._get_linear_module_name_other_than_attn(model))
del model
# It's important to add this context to raise an error on recompilation
with torch._dynamo.config.patch(error_on_recompile=True):
self._check_model_hotswap(
tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
)
def test_enable_lora_hotswap_called_after_adapter_added_raises(self):
# ensure that enable_lora_hotswap is called before loading the first adapter
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.")
with pytest.raises(RuntimeError, match=msg):
model.enable_lora_hotswap(target_rank=32)
def test_enable_lora_hotswap_called_after_adapter_added_warning(self, caplog):
# ensure that enable_lora_hotswap is called before loading the first adapter
import logging
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
msg = (
"It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
)
with caplog.at_level(logging.WARNING):
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
assert any(msg in record.message for record in caplog.records)
def test_enable_lora_hotswap_called_after_adapter_added_ignore(self, caplog):
# check possibility to ignore the error/warning
import logging
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
with caplog.at_level(logging.WARNING):
model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
assert len(caplog.records) == 0
def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
# check that wrong argument value raises an error
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
msg = re.escape("check_compiles should be one of 'error', 'warn', or 'ignore', got 'wrong-argument' instead.")
with pytest.raises(ValueError, match=msg):
model.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")
def test_hotswap_second_adapter_targets_more_layers_raises(self, tmp_path, caplog):
# check the error and log
import logging
# at the moment, PEFT requires the 2nd adapter to target the same or a subset of layers
target_modules0 = ["to_q"]
target_modules1 = ["to_q", "to_k"]
with pytest.raises(RuntimeError): # peft raises RuntimeError
with caplog.at_level(logging.ERROR):
self._check_model_hotswap(
tmp_path,
do_compile=True,
rank0=8,
rank1=8,
target_modules0=target_modules0,
target_modules1=target_modules1,
)
assert any("Hotswapping adapter0 was unsuccessful" in record.message for record in caplog.records)
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
@require_torch_version_greater("2.7.1")
def test_hotswapping_compile_on_different_shapes(self, tmp_path, rank0, rank1):
different_shapes_for_compilation = self.different_shapes_for_compilation
if different_shapes_for_compilation is None:
pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
# Specifying `use_duck_shape=False` instructs the compiler if it should use the same symbolic
# variable to represent input sizes that are the same. For more details,
# check out this [comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790).
torch.fx.experimental._config.use_duck_shape = False
target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
with torch._dynamo.config.patch(error_on_recompile=True):
self._check_model_hotswap(
tmp_path,
do_compile=True,
rank0=rank0,
rank1=rank1,
target_modules0=target_modules,
)
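The user-level flow these hotswap tests exercise looks roughly like the sketch below. It is not part of this commit: the checkpoint paths and target rank are placeholders, and the tiny Flux repo id is borrowed from the Flux test config later in this commit; it also assumes a CUDA device is available.
# Hedged sketch of the hotswap flow exercised above; paths and target_rank are placeholders.
import torch

from diffusers import FluxTransformer2DModel

model = FluxTransformer2DModel.from_pretrained(
    "hf-internal-testing/tiny-flux-pipe", subfolder="transformer"
).to("cuda")

# Prepare for hotswapping before the first adapter is loaded, so swapping in an
# adapter with a different rank does not force a recompilation.
model.enable_lora_hotswap(target_rank=16)
model.load_lora_adapter("path/to/adapter0", adapter_name="adapter0", prefix=None)
model = torch.compile(model, mode="reduce-overhead")

# ... run inference with adapter0 ...

# Swap the weights in place under the same adapter name; no recompilation expected.
model.load_lora_adapter("path/to/adapter1", adapter_name="adapter0", hotswap=True, prefix=None)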


@@ -27,6 +27,7 @@ from diffusers.utils.testing_utils import _check_safetensors_serialization
from diffusers.utils.torch_utils import get_torch_cuda_device_capability
from ...testing_utils import (
assert_tensors_close,
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_peak_memory_stats,
@@ -122,8 +123,8 @@ class CPUOffloadTesterMixin:
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
assert torch.allclose(base_output[0], new_output[0], atol=1e-5), (
"Output should match with CPU offloading"
assert_tensors_close(
base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match with CPU offloading"
)
@require_offload_support
@@ -156,7 +157,9 @@ class CPUOffloadTesterMixin:
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
assert torch.allclose(base_output[0], new_output[0], atol=1e-5), "Output should match with disk offloading"
assert_tensors_close(
base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match with disk offloading"
)
@require_offload_support
def test_disk_offload_with_safetensors(self):
@@ -183,8 +186,12 @@ class CPUOffloadTesterMixin:
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
assert torch.allclose(base_output[0], new_output[0], atol=1e-5), (
"Output should match with disk offloading (safetensors)"
assert_tensors_close(
base_output[0],
new_output[0],
atol=1e-5,
rtol=0,
msg="Output should match with disk offloading (safetensors)",
)
@@ -247,17 +254,33 @@ class GroupOffloadTesterMixin:
)
output_with_group_offloading4 = run_forward(model)
assert torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5), (
"Output should match with block-level offloading"
assert_tensors_close(
output_without_group_offloading,
output_with_group_offloading1,
atol=1e-5,
rtol=0,
msg="Output should match with block-level offloading",
)
assert torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5), (
"Output should match with non-blocking block-level offloading"
assert_tensors_close(
output_without_group_offloading,
output_with_group_offloading2,
atol=1e-5,
rtol=0,
msg="Output should match with non-blocking block-level offloading",
)
assert torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5), (
"Output should match with leaf-level offloading"
assert_tensors_close(
output_without_group_offloading,
output_with_group_offloading3,
atol=1e-5,
rtol=0,
msg="Output should match with leaf-level offloading",
)
assert torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5), (
"Output should match with leaf-level offloading with stream"
assert_tensors_close(
output_without_group_offloading,
output_with_group_offloading4,
atol=1e-5,
rtol=0,
msg="Output should match with leaf-level offloading with stream",
)
@require_group_offload_support
@@ -345,8 +368,12 @@ class GroupOffloadTesterMixin:
raise ValueError(f"Following files are missing: {', '.join(missing_files)}")
output_with_group_offloading = _run_forward(model, inputs_dict)
assert torch.allclose(output_without_group_offloading, output_with_group_offloading, atol=atol), (
"Output should match with disk-based group offloading"
assert_tensors_close(
output_without_group_offloading,
output_with_group_offloading,
atol=atol,
rtol=0,
msg="Output should match with disk-based group offloading",
)


@@ -22,6 +22,7 @@ from huggingface_hub import hf_hub_download, snapshot_download
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from ...testing_utils import (
assert_tensors_close,
backend_empty_cache,
is_single_file,
nightly,
@@ -146,9 +147,8 @@ class SingleFileTesterMixin:
f"pretrained {param.shape} vs single file {param_single_file.shape}"
)
assert torch.allclose(param, param_single_file, rtol=1e-5, atol=1e-5), (
f"Parameter values differ for {key}: "
f"max difference {torch.max(torch.abs(param - param_single_file)).item()}"
assert_tensors_close(
param, param_single_file, atol=1e-5, rtol=1e-5, msg=f"Parameter values differ for {key}"
)
def test_single_file_loading_local_files_only(self):


@@ -23,13 +23,14 @@ from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProc
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
BitsAndBytesTesterMixin,
ContextParallelTesterMixin,
GGUFTesterMixin,
IPAdapterTesterMixin,
LoraHotSwappingForModelTesterMixin,
LoraTesterMixin,
MemoryTesterMixin,
ModelOptTesterMixin,
@@ -94,10 +95,26 @@ def create_flux_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]:
return {"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict}
class FluxTransformerTesterConfig:
model_class = FluxTransformer2DModel
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
pretrained_model_kwargs = {"subfolder": "transformer"}
class FluxTransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return FluxTransformer2DModel
@property
def pretrained_model_name_or_path(self):
return "hf-internal-testing/tiny-flux-pipe"
@property
def pretrained_model_kwargs(self):
return {"subfolder": "transformer"}
@property
def output_shape(self) -> tuple[int, int]:
return (16, 4)
@property
def input_shape(self) -> tuple[int, int]:
return (16, 4)
@property
def generator(self):
@@ -136,14 +153,6 @@ class FluxTransformerTesterConfig:
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
}
@property
def input_shape(self) -> tuple[int, int]:
return (16, 4)
@property
def output_shape(self) -> tuple[int, int]:
return (16, 4)
class TestFluxTransformer(FluxTransformerTesterConfig, ModelTesterMixin):
def test_deprecated_inputs_img_txt_ids_3d(self):


@@ -131,6 +131,59 @@ def torch_all_close(a, b, *args, **kwargs):
return True
def assert_tensors_close(
actual: "torch.Tensor",
expected: "torch.Tensor",
atol: float = 1e-5,
rtol: float = 1e-5,
msg: str = "",
) -> None:
"""
Assert that two tensors are close within tolerance.
Uses the same formula as torch.allclose: |actual - expected| <= atol + rtol * |expected|
Provides concise, actionable error messages without dumping full tensors.
Args:
actual: The actual tensor from the computation.
expected: The expected tensor to compare against.
atol: Absolute tolerance.
rtol: Relative tolerance.
msg: Optional message prefix for the assertion error.
Raises:
AssertionError: If tensors have different shapes or values exceed tolerance.
Example:
>>> assert_tensors_close(output, expected_output, atol=1e-5, rtol=1e-5, msg="Forward pass")
"""
if not is_torch_available():
raise ValueError("PyTorch needs to be installed to use this function.")
if actual.shape != expected.shape:
raise AssertionError(f"{msg} Shape mismatch: actual {actual.shape} vs expected {expected.shape}")
if not torch.allclose(actual, expected, atol=atol, rtol=rtol):
abs_diff = (actual - expected).abs()
max_diff = abs_diff.max().item()
flat_idx = abs_diff.argmax().item()
max_idx = tuple(idx.item() for idx in torch.unravel_index(torch.tensor(flat_idx), actual.shape))
threshold = atol + rtol * expected.abs()
mismatched = (abs_diff > threshold).sum().item()
total = actual.numel()
raise AssertionError(
f"{msg}\n"
f"Tensors not close! Mismatched elements: {mismatched}/{total} ({100 * mismatched / total:.1f}%)\n"
f" Max diff: {max_diff:.6e} at index {max_idx}\n"
f" Actual: {actual.flatten()[flat_idx].item():.6e}\n"
f" Expected: {expected.flatten()[flat_idx].item():.6e}\n"
f" atol: {atol:.6e}, rtol: {rtol:.6e}"
)
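To make the failure format concrete, here is a small usage sketch. It assumes assert_tensors_close as defined above is in scope (the exact import path depends on the suite layout), and the printed lines approximate the message the helper builds.
# Illustrative only: how a failing comparison reports itself.
import torch

actual = torch.zeros(2, 3)
expected = torch.zeros(2, 3)
expected[1, 2] = 1e-3  # one element pushed past atol=1e-5

try:
    assert_tensors_close(actual, expected, atol=1e-5, rtol=0, msg="Forward pass")
except AssertionError as err:
    print(err)
    # Forward pass
    # Tensors not close! Mismatched elements: 1/6 (16.7%)
    #  Max diff: 1.000000e-03 at index (1, 2)
    #  Actual: 0.000000e+00
    #  Expected: 1.000000e-03
    #  atol: 1.000000e-05, rtol: 0.000000e+00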
def numpy_cosine_similarity_distance(a, b):
similarity = np.dot(a, b) / (norm(a) * norm(b))
distance = 1.0 - similarity.mean()