mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
DN6
2026-01-07 12:26:44 +05:30
parent e0ab03d79b
commit ba475eee8d
7 changed files with 81 additions and 112 deletions

View File

@@ -40,7 +40,6 @@ class AttentionTesterMixin:
Expected class attributes to be set by subclasses:
- model_class: The model class to test
- base_precision: Tolerance for floating point comparisons (default: 1e-3)
- uses_custom_attn_processor: Whether model uses custom attention processors (default: False)
Expected methods to be implemented by subclasses:
@@ -51,8 +50,7 @@ class AttentionTesterMixin:
Use `pytest -m "not attention"` to skip these tests
"""
base_precision = 1e-3
@torch.no_grad()
def test_fuse_unfuse_qkv_projections(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
@@ -63,14 +61,10 @@ class AttentionTesterMixin:
if not hasattr(model, "fuse_qkv_projections"):
pytest.skip("Model does not support QKV projection fusion.")
# Get output before fusion
-with torch.no_grad():
-    output_before_fusion = model(**inputs_dict, return_dict=False)[0]
+output_before_fusion = model(**inputs_dict, return_dict=False)[0]
# Fuse projections
model.fuse_qkv_projections()
# Verify fusion occurred by checking for fused attributes
has_fused_projections = False
for module in model.modules():
if isinstance(module, AttentionModuleMixin):
@@ -80,38 +74,30 @@ class AttentionTesterMixin:
break
if has_fused_projections:
# Get output after fusion
-with torch.no_grad():
-    output_after_fusion = model(**inputs_dict, return_dict=False)[0]
+output_after_fusion = model(**inputs_dict, return_dict=False)[0]
# Verify outputs match
assert_tensors_close(
output_before_fusion,
output_after_fusion,
-atol=self.base_precision,
+atol=1e-3,
rtol=0,
msg="Output should not change after fusing projections",
)
# Unfuse projections
model.unfuse_qkv_projections()
# Verify unfusion occurred
for module in model.modules():
if isinstance(module, AttentionModuleMixin):
assert not hasattr(module, "to_qkv"), "to_qkv should be removed after unfusing"
assert not hasattr(module, "to_kv"), "to_kv should be removed after unfusing"
assert not module.fused_projections, "fused_projections flag should be False"
# Get output after unfusion
-with torch.no_grad():
-    output_after_unfusion = model(**inputs_dict, return_dict=False)[0]
+output_after_unfusion = model(**inputs_dict, return_dict=False)[0]
# Verify outputs still match
assert_tensors_close(
output_before_fusion,
output_after_unfusion,
-atol=self.base_precision,
+atol=1e-3,
rtol=0,
msg="Output should match original after unfusing projections",
)
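A minimal, self-contained sketch of the invariant test_fuse_unfuse_qkv_projections asserts, using a toy attention layer (hypothetical names, not the diffusers AttentionModuleMixin implementation): fusing the separate Q/K/V projections into one matmul must leave outputs unchanged within atol=1e-3.

import torch
import torch.nn as nn


class ToySelfAttention(nn.Module):
    def __init__(self, dim=32, heads=4):
        super().__init__()
        self.dim, self.heads = dim, heads
        self.to_q = nn.Linear(dim, dim, bias=False)
        self.to_k = nn.Linear(dim, dim, bias=False)
        self.to_v = nn.Linear(dim, dim, bias=False)
        self.to_qkv = None  # created by fuse_qkv_projections

    def fuse_qkv_projections(self):
        # Concatenate the three projection weights into a single matmul.
        fused_weight = torch.cat([self.to_q.weight, self.to_k.weight, self.to_v.weight], dim=0)
        self.to_qkv = nn.Linear(self.dim, 3 * self.dim, bias=False)
        self.to_qkv.weight.data.copy_(fused_weight)

    def forward(self, x):
        if self.to_qkv is not None:
            q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        else:
            q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
        b, n, _ = x.shape
        q, k, v = (t.view(b, n, self.heads, -1).transpose(1, 2) for t in (q, k, v))
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        return out.transpose(1, 2).reshape(b, n, self.dim)


@torch.no_grad()  # decorator form, as introduced by this commit
def check_fusion_is_lossless():
    attn = ToySelfAttention().eval()
    x = torch.randn(2, 8, 32)
    output_before_fusion = attn(x)
    attn.fuse_qkv_projections()
    output_after_fusion = attn(x)
    assert torch.allclose(output_before_fusion, output_after_fusion, atol=1e-3, rtol=0)


check_fusion_is_lossless()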

View File

@@ -196,6 +196,7 @@ class CacheTesterMixin:
model.disable_cache()
@torch.no_grad()
def _test_reset_stateful_cache(self):
"""Test that _reset_stateful_cache resets the cache state."""
init_dict = self.get_init_dict()
@@ -207,11 +208,8 @@ class CacheTesterMixin:
model.enable_cache(config)
# Run forward to populate cache state
-with torch.no_grad():
-    _ = model(**inputs_dict, return_dict=False)[0]
+_ = model(**inputs_dict, return_dict=False)[0]
# Reset should not raise any errors
model._reset_stateful_cache()
model.disable_cache()
@@ -358,6 +356,7 @@ class FirstBlockCacheTesterMixin(FirstBlockCacheConfigMixin, CacheTesterMixin):
"Cached output should be different from non-cached output due to cache approximation."
)
@torch.no_grad()
def _test_reset_stateful_cache(self):
"""Test that _reset_stateful_cache resets the FBC cache state (requires cache_context)."""
init_dict = self.get_init_dict()
@@ -368,12 +367,9 @@ class FirstBlockCacheTesterMixin(FirstBlockCacheConfigMixin, CacheTesterMixin):
config = self._get_cache_config()
model.enable_cache(config)
# FBC requires cache_context to be set for inference
with model.cache_context("fbc_test"):
-    with torch.no_grad():
-        _ = model(**inputs_dict, return_dict=False)[0]
+    _ = model(**inputs_dict, return_dict=False)[0]
# Reset should not raise any errors
model._reset_stateful_cache()
model.disable_cache()
@@ -491,6 +487,7 @@ class FasterCacheTesterMixin(FasterCacheConfigMixin, CacheTesterMixin):
"Cached output should be different from non-cached output due to cache approximation."
)
@torch.no_grad()
def _test_reset_stateful_cache(self):
"""Test that _reset_stateful_cache resets the FasterCache state."""
init_dict = self.get_init_dict()
@@ -501,12 +498,9 @@ class FasterCacheTesterMixin(FasterCacheConfigMixin, CacheTesterMixin):
config = self._get_cache_config()
model.enable_cache(config)
# First pass with timestep outside skip range
self._current_timestep[0] = 1000
-with torch.no_grad():
-    _ = model(**inputs_dict, return_dict=False)[0]
+_ = model(**inputs_dict, return_dict=False)[0]
# Reset should not raise any errors
model._reset_stateful_cache()
model.disable_cache()
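The change in this file swaps inline torch.no_grad() context managers for the decorator form on the test methods. A small sketch of why the two are interchangeable here: the decorator disables gradient tracking for the entire function body (toy model, illustrative only).

import torch

model = torch.nn.Linear(4, 4)
x = torch.randn(2, 4)

# Before: the context manager wraps only the forward call.
with torch.no_grad():
    out_ctx = model(x)


# After: the decorator disables gradient tracking for the whole test body.
@torch.no_grad()
def run_forward():
    return model(x)


out_dec = run_forward()
assert out_ctx.requires_grad is False and out_dec.requires_grad is False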

View File

@@ -145,7 +145,6 @@ class BaseModelTesterConfig:
- pretrained_model_name_or_path: Hub repository ID for pretrained model (default: None)
- pretrained_model_kwargs: Additional kwargs for from_pretrained (default: {})
- output_shape: Expected output shape for output validation tests (default: None)
- base_precision: Default tolerance for floating point comparisons (default: 1e-3)
- model_split_percents: Percentages for model parallelism tests (default: [0.5, 0.7])
Required methods (must be implemented by subclasses):
@@ -259,6 +258,7 @@ class ModelTesterMixin:
pass
"""
@torch.no_grad()
def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=5e-5):
torch.manual_seed(0)
model = self.model_class(**self.get_init_dict())
@@ -269,7 +269,6 @@ class ModelTesterMixin:
new_model = self.model_class.from_pretrained(tmp_path)
new_model.to(torch_device)
# check if all parameters shape are the same
for param_name in model.state_dict().keys():
param_1 = model.state_dict()[param_name]
param_2 = new_model.state_dict()[param_name]
@@ -277,12 +276,12 @@ class ModelTesterMixin:
f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}"
)
-with torch.no_grad():
-    image = model(**self.get_dummy_inputs(), return_dict=False)[0]
-    new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
+image = model(**self.get_dummy_inputs(), return_dict=False)[0]
+new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
@torch.no_grad()
def test_from_save_pretrained_variant(self, tmp_path, atol=5e-5, rtol=0):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
@@ -291,18 +290,15 @@ class ModelTesterMixin:
model.save_pretrained(tmp_path, variant="fp16")
new_model = self.model_class.from_pretrained(tmp_path, variant="fp16")
# non-variant cannot be loaded
with pytest.raises(OSError) as exc_info:
self.model_class.from_pretrained(tmp_path)
# make sure that error message states what keys are missing
assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(exc_info.value)
new_model.to(torch_device)
-with torch.no_grad():
-    image = model(**self.get_dummy_inputs(), return_dict=False)[0]
-    new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
+image = model(**self.get_dummy_inputs(), return_dict=False)[0]
+new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
@@ -324,16 +320,15 @@ class ModelTesterMixin:
new_model = self.model_class.from_pretrained(tmp_path, low_cpu_mem_usage=False, torch_dtype=dtype)
assert new_model.dtype == dtype
@torch.no_grad()
def test_determinism(self, atol=1e-5, rtol=0):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
-with torch.no_grad():
-    first = model(**self.get_dummy_inputs(), return_dict=False)[0]
-    second = model(**self.get_dummy_inputs(), return_dict=False)[0]
+first = model(**self.get_dummy_inputs(), return_dict=False)[0]
+second = model(**self.get_dummy_inputs(), return_dict=False)[0]
# Filter out NaN values before comparison
first_flat = first.flatten()
second_flat = second.flatten()
mask = ~(torch.isnan(first_flat) | torch.isnan(second_flat))
@@ -344,24 +339,23 @@ class ModelTesterMixin:
first_filtered, second_filtered, atol=atol, rtol=rtol, msg="Model outputs are not deterministic"
)
@torch.no_grad()
def test_output(self, expected_output_shape=None):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
inputs_dict = self.get_dummy_inputs()
-with torch.no_grad():
-    output = model(**inputs_dict, return_dict=False)[0]
+output = model(**inputs_dict, return_dict=False)[0]
assert output is not None, "Model output is None"
assert output[0].shape == expected_output_shape or self.output_shape, (
f"Output shape does not match expected. Expected {expected_output_shape}, got {output.shape}"
)
@torch.no_grad()
def test_outputs_equivalence(self):
def set_nan_tensor_to_zero(t):
# Temporary fallback until `aten::_index_put_impl_` is implemented in mps
# Track progress in https://github.com/pytorch/pytorch/issues/77764
device = t.device
if device.type == "mps":
t = t.to("cpu")
@@ -390,9 +384,8 @@ class ModelTesterMixin:
model.to(torch_device)
model.eval()
-with torch.no_grad():
-    outputs_dict = model(**self.get_dummy_inputs())
-    outputs_tuple = model(**self.get_dummy_inputs(), return_dict=False)
+outputs_dict = model(**self.get_dummy_inputs())
+outputs_tuple = model(**self.get_dummy_inputs(), return_dict=False)
recursive_check(outputs_tuple, outputs_dict)
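A minimal sketch of the NaN-filtered determinism comparison used by test_determinism above (toy model and torch.testing.assert_close as stand-ins; the real test runs the model under test with assert_tensors_close):

import torch

torch.manual_seed(0)
model = torch.nn.Linear(8, 8).eval()
x = torch.randn(4, 8)

with torch.no_grad():
    first = model(x)
    second = model(x)

# Mask out positions that are NaN in either output before comparing.
first_flat, second_flat = first.flatten(), second.flatten()
mask = ~(torch.isnan(first_flat) | torch.isnan(second_flat))
torch.testing.assert_close(first_flat[mask], second_flat[mask], atol=1e-5, rtol=0)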
@@ -465,6 +458,7 @@ class ModelTesterMixin:
reason="float16 and bfloat16 can only be use for inference with an accelerator",
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
@torch.no_grad()
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
@@ -479,9 +473,8 @@ class ModelTesterMixin:
else:
assert param.data.dtype == dtype
-with torch.no_grad():
-    output = model(**self.get_dummy_inputs(), return_dict=False)[0]
-    output_loaded = model_loaded(**self.get_dummy_inputs(), return_dict=False)[0]
+output = model(**self.get_dummy_inputs(), return_dict=False)[0]
+output_loaded = model_loaded(**self.get_dummy_inputs(), return_dict=False)[0]
assert_tensors_close(output, output_loaded, atol=1e-4, rtol=0, msg=f"Loaded model output differs for {dtype}")
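A minimal sketch of the save/load round-trip property that test_from_save_pretrained checks, written with plain torch state_dict serialization as a hypothetical stand-in for ModelMixin.save_pretrained/from_pretrained, using the same tolerance:

import os
import tempfile

import torch

torch.manual_seed(0)


def build_model():
    return torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.GELU(), torch.nn.Linear(8, 8)).eval()


model = build_model()
x = torch.randn(2, 8)

with tempfile.TemporaryDirectory() as tmp_path:
    weights = os.path.join(tmp_path, "weights.pt")
    torch.save(model.state_dict(), weights)
    reloaded = build_model()
    reloaded.load_state_dict(torch.load(weights))

with torch.no_grad():
    image, new_image = model(x), reloaded(x)

# Reloading the weights must reproduce the forward pass within tolerance.
torch.testing.assert_close(image, new_image, atol=5e-5, rtol=5e-5)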

View File

@@ -98,6 +98,7 @@ class TorchCompileTesterMixin:
_ = model(**inputs_dict)
_ = model(**inputs_dict)
@torch.no_grad()
def test_compile_with_group_offloading(self):
if not self.model_class._supports_group_offloading:
pytest.skip("Model does not support group offloading.")
@@ -120,10 +121,10 @@ class TorchCompileTesterMixin:
model.enable_group_offload(**group_offload_kwargs)
model.compile()
-with torch.no_grad():
-    _ = model(**inputs_dict)
-    _ = model(**inputs_dict)
+_ = model(**inputs_dict)
+_ = model(**inputs_dict)
@torch.no_grad()
def test_compile_on_different_shapes(self):
if self.different_shapes_for_compilation is None:
pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
@@ -135,10 +136,11 @@ class TorchCompileTesterMixin:
model = torch.compile(model, fullgraph=True, dynamic=True)
for height, width in self.different_shapes_for_compilation:
-with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
+with torch._dynamo.config.patch(error_on_recompile=True):
inputs_dict = self.get_dummy_inputs(height=height, width=width)
_ = model(**inputs_dict)
@torch.no_grad()
def test_compile_works_with_aot(self, tmp_path):
from torch._inductor.package import load_package
@@ -155,6 +157,5 @@ class TorchCompileTesterMixin:
model.forward = loaded_binary
-with torch.no_grad():
-    _ = model(**inputs_dict)
-    _ = model(**inputs_dict)
+_ = model(**inputs_dict)
+_ = model(**inputs_dict)
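A minimal sketch of the dynamic-shape compilation check above: with dynamic=True, running several input shapes under error_on_recompile should not trigger a recompile (toy module as a stand-in for the compiled diffusers model):

import torch


@torch.no_grad()
def check_no_recompiles():
    model = torch.compile(torch.nn.Linear(16, 16).eval(), fullgraph=True, dynamic=True)
    for batch_size in (2, 4, 8):
        # error_on_recompile turns any recompilation into a hard failure.
        with torch._dynamo.config.patch(error_on_recompile=True):
            _ = model(torch.randn(batch_size, 16))


check_no_recompiles()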

View File

@@ -208,7 +208,8 @@ class GroupOffloadTesterMixin:
"""
@require_group_offload_support
-def test_group_offloading(self, record_stream=False):
+@pytest.mark.parametrize("record_stream", [False, True])
+def test_group_offloading(self, record_stream):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
torch.manual_seed(0)
@@ -280,8 +281,10 @@ class GroupOffloadTesterMixin:
)
@require_group_offload_support
+@pytest.mark.parametrize("record_stream", [False, True])
+@pytest.mark.parametrize("offload_type", ["block_level", "leaf_level"])
+@torch.no_grad()
-def test_group_offloading_with_layerwise_casting(self, record_stream=False, offload_type="block_level"):
+def test_group_offloading_with_layerwise_casting(self, record_stream, offload_type):
torch.manual_seed(0)
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
@@ -306,9 +309,11 @@ class GroupOffloadTesterMixin:
_ = model(**inputs_dict)[0]
@require_group_offload_support
+@pytest.mark.parametrize("record_stream", [False, True])
+@pytest.mark.parametrize("offload_type", ["block_level", "leaf_level"])
@torch.no_grad()
@torch.inference_mode()
-def test_group_offloading_with_disk(self, tmp_path, offload_type="block_level", record_stream=False, atol=1e-5):
+def test_group_offloading_with_disk(self, tmp_path, record_stream, offload_type, atol=1e-5):
def _has_generator_arg(model):
sig = inspect.signature(model.forward)
params = sig.parameters
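The signature changes in this file replace default arguments, which pytest never varies, with pytest.mark.parametrize, which generates one test case per combination. A minimal illustrative sketch (hypothetical test, not the real group-offloading logic):

import pytest


@pytest.mark.parametrize("record_stream", [False, True])
@pytest.mark.parametrize("offload_type", ["block_level", "leaf_level"])
def test_offload_matrix(record_stream, offload_type):
    # pytest generates four cases: (block_level, False), (block_level, True),
    # (leaf_level, False), (leaf_level, True).
    assert offload_type in {"block_level", "leaf_level"}
    assert isinstance(record_stream, bool)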

View File

@@ -27,9 +27,9 @@ from ...testing_utils import (
)
@torch.no_grad()
def _context_parallel_worker(rank, world_size, model_class, init_dict, cp_dict, inputs_dict, result_queue):
try:
# Setup distributed environment
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
@@ -56,8 +56,7 @@ def _context_parallel_worker(rank, world_size, model_class, init_dict, cp_dict,
cp_config = ContextParallelConfig(**cp_dict)
model.enable_parallelism(config=cp_config)
-with torch.no_grad():
-    output = model(**inputs_on_device, return_dict=False)[0]
+output = model(**inputs_on_device, return_dict=False)[0]
if rank == 0:
result_queue.put(("success", output.shape))
@@ -73,8 +72,6 @@ def _context_parallel_worker(rank, world_size, model_class, init_dict, cp_dict,
@is_context_parallel
@require_torch_multi_accelerator
class ContextParallelTesterMixin:
base_precision = 1e-3
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
def test_context_parallel_inference(self, cp_type):
if not torch.distributed.is_available():
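The worker above reports failures through a multiprocessing result queue and only rank 0 puts the output shape. A minimal runnable sketch of that reporting pattern with a dummy worker (no torch.distributed, no model; names are illustrative only):

import multiprocessing as mp


def _dummy_worker(rank, world_size, result_queue):
    try:
        output_shape = (2, 16)  # stand-in for model(**inputs_on_device)[0].shape
        if rank == 0:
            result_queue.put(("success", output_shape))
    except Exception as e:
        result_queue.put(("error", f"rank {rank}: {e}"))


if __name__ == "__main__":
    world_size = 2
    ctx = mp.get_context("spawn")
    result_queue = ctx.Queue()
    workers = [ctx.Process(target=_dummy_worker, args=(rank, world_size, result_queue)) for rank in range(world_size)]
    for worker in workers:
        worker.start()
    status, payload = result_queue.get(timeout=30)
    for worker in workers:
        worker.join()
    assert status == "success", payload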

View File

@@ -168,15 +168,15 @@ class QuantizationTesterMixin:
f"Memory ratio {ratio:.2f} is less than expected ({expected_memory_reduction}x). unquantized={mem}, quantized={mem_quantized}"
)
@torch.no_grad()
def _test_quantization_inference(self, config_kwargs):
model_quantized = self._create_quantized_model(config_kwargs)
-with torch.no_grad():
-    inputs = self.get_dummy_inputs()
-    output = model_quantized(**inputs, return_dict=False)[0]
+inputs = self.get_dummy_inputs()
+output = model_quantized(**inputs, return_dict=False)[0]
-    assert output is not None, "Model output is None"
-    assert not torch.isnan(output).any(), "Model output contains NaN"
+assert output is not None, "Model output is None"
+assert not torch.isnan(output).any(), "Model output contains NaN"
def _test_quantization_dtype_assignment(self, config_kwargs):
model = self._create_quantized_model(config_kwargs)
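A minimal sketch of the post-quantization sanity checks these helpers run (output exists, no NaNs), using torch's dynamic int8 quantization as a CPU-only stand-in for the backends the mixin actually exercises through _create_quantized_model:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16)).eval()
model_quantized = torch.ao.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

with torch.no_grad():
    output = model_quantized(torch.randn(4, 16))

assert output is not None, "Model output is None"
assert not torch.isnan(output).any(), "Model output contains NaN"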
@@ -196,6 +196,7 @@ class QuantizationTesterMixin:
model.to(torch_device)
@torch.no_grad()
def _test_quantization_lora_inference(self, config_kwargs):
try:
from peft import LoraConfig
@@ -217,13 +218,13 @@ class QuantizationTesterMixin:
)
model.add_adapter(lora_config)
-with torch.no_grad():
-    inputs = self.get_dummy_inputs()
-    output = model(**inputs, return_dict=False)[0]
+inputs = self.get_dummy_inputs()
+output = model(**inputs, return_dict=False)[0]
-    assert output is not None, "Model output is None with LoRA"
-    assert not torch.isnan(output).any(), "Model output contains NaN with LoRA"
+assert output is not None, "Model output is None with LoRA"
+assert not torch.isnan(output).any(), "Model output contains NaN with LoRA"
@torch.no_grad()
def _test_quantization_serialization(self, config_kwargs, tmp_path):
model = self._create_quantized_model(config_kwargs)
@@ -231,10 +232,9 @@ class QuantizationTesterMixin:
model_loaded = self.model_class.from_pretrained(str(tmp_path))
-with torch.no_grad():
-    inputs = self.get_dummy_inputs()
-    output = model_loaded(**inputs, return_dict=False)[0]
-    assert not torch.isnan(output).any(), "Loaded model output contains NaN"
+inputs = self.get_dummy_inputs()
+output = model_loaded(**inputs, return_dict=False)[0]
+assert not torch.isnan(output).any(), "Loaded model output contains NaN"
def _test_quantized_layers(self, config_kwargs):
model_fp = self._load_unquantized_model()
@@ -317,6 +317,7 @@ class QuantizationTesterMixin:
f"Model with exclusions should be larger. With exclusion: {mem_with_exclusion}, fully quantized: {mem_fully_quantized}"
)
@torch.no_grad()
def _test_quantization_device_map(self, config_kwargs):
"""
Test that quantized models work correctly with device_map="auto".
@@ -326,17 +327,15 @@ class QuantizationTesterMixin:
"""
model = self._create_quantized_model(config_kwargs, device_map="auto")
# Verify device map is set
assert hasattr(model, "hf_device_map"), "Model should have hf_device_map attribute"
assert model.hf_device_map is not None, "hf_device_map should not be None"
# Verify inference works
-with torch.no_grad():
-    inputs = self.get_dummy_inputs()
-    output = model(**inputs, return_dict=False)[0]
-    assert output is not None, "Model output is None"
-    assert not torch.isnan(output).any(), "Model output contains NaN"
+inputs = self.get_dummy_inputs()
+output = model(**inputs, return_dict=False)[0]
+assert output is not None, "Model output is None"
+assert not torch.isnan(output).any(), "Model output contains NaN"
@torch.no_grad()
def _test_dequantize(self, config_kwargs):
"""
Test that dequantize() converts quantized model back to standard linear layers.
@@ -346,24 +345,19 @@ class QuantizationTesterMixin:
"""
model = self._create_quantized_model(config_kwargs)
# Verify model has dequantize method
if not hasattr(model, "dequantize"):
pytest.skip("Model does not have dequantize method")
# Dequantize the model
model.dequantize()
# Verify no modules are quantized after dequantization
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
assert not self._is_module_quantized(module), f"Module {name} is still quantized after dequantize()"
# Verify inference still works after dequantization
-with torch.no_grad():
-    inputs = self.get_dummy_inputs()
-    output = model(**inputs, return_dict=False)[0]
-    assert output is not None, "Model output is None after dequantization"
-    assert not torch.isnan(output).any(), "Model output contains NaN after dequantization"
+inputs = self.get_dummy_inputs()
+output = model(**inputs, return_dict=False)[0]
+assert output is not None, "Model output is None after dequantization"
+assert not torch.isnan(output).any(), "Model output contains NaN after dequantization"
def _test_quantization_training(self, config_kwargs):
"""
@@ -566,6 +560,7 @@ class BitsAndBytesTesterMixin(BitsAndBytesConfigMixin, QuantizationTesterMixin):
torch.bfloat16,
], f"Unexpected dtype: {model.config['_pre_quantization_dtype']}"
@torch.no_grad()
def test_bnb_keep_modules_in_fp32(self):
if not hasattr(self.model_class, "_keep_in_fp32_modules"):
pytest.skip(f"{self.model_class.__name__} does not have _keep_in_fp32_modules")
@@ -589,9 +584,8 @@ class BitsAndBytesTesterMixin(BitsAndBytesConfigMixin, QuantizationTesterMixin):
f"Module {name} should be uint8 but is {module.weight.dtype}"
)
-with torch.no_grad():
-    inputs = self.get_dummy_inputs()
-    _ = model(**inputs)
+inputs = self.get_dummy_inputs()
+_ = model(**inputs)
finally:
if original_fp32_modules is not None:
self.model_class._keep_in_fp32_modules = original_fp32_modules
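A minimal sketch of the _keep_in_fp32_modules idea that test_bnb_keep_modules_in_fp32 exercises: selected modules stay in float32 while the rest of the model is cast down (plain torch; module names are hypothetical, the real list comes from the model class):

import torch
import torch.nn as nn

model = nn.ModuleDict({"proj_in": nn.Linear(8, 8), "norm_out": nn.LayerNorm(8)})
keep_in_fp32_modules = {"norm_out"}  # hypothetical list for illustration

model.to(torch.float16)
for name, module in model.named_children():
    if name in keep_in_fp32_modules:
        module.to(torch.float32)

assert model["proj_in"].weight.dtype == torch.float16
assert model["norm_out"].weight.dtype == torch.float32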
@@ -1097,6 +1091,7 @@ class QuantizationCompileTesterMixin:
backend_empty_cache(torch_device)
torch.compiler.reset()
@torch.no_grad()
def _test_torch_compile(self, config_kwargs):
"""
Test that torch.compile works correctly with a quantized model.
@@ -1108,16 +1103,15 @@ class QuantizationCompileTesterMixin:
model.to(torch_device)
model.eval()
# Compile the model with fullgraph=True to ensure no graph breaks
model = torch.compile(model, fullgraph=True)
# Run inference with error_on_recompile to detect recompilation issues
-with torch.no_grad(), torch._dynamo.config.patch(error_on_recompile=True):
+with torch._dynamo.config.patch(error_on_recompile=True):
inputs = self.get_dummy_inputs()
output = model(**inputs, return_dict=False)[0]
assert output is not None, "Model output is None"
assert not torch.isnan(output).any(), "Model output contains NaN"
@torch.no_grad()
def _test_torch_compile_with_group_offload(self, config_kwargs, use_stream=False):
"""
Test that torch.compile works correctly with a quantized model and group offloading.
@@ -1143,11 +1137,10 @@ class QuantizationCompileTesterMixin:
model.enable_group_offload(**group_offload_kwargs)
model = torch.compile(model)
-with torch.no_grad():
-    inputs = self.get_dummy_inputs()
-    output = model(**inputs, return_dict=False)[0]
-    assert output is not None, "Model output is None"
-    assert not torch.isnan(output).any(), "Model output contains NaN"
+inputs = self.get_dummy_inputs()
+output = model(**inputs, return_dict=False)[0]
+assert output is not None, "Model output is None"
+assert not torch.isnan(output).any(), "Model output contains NaN"
@is_bitsandbytes