Mirror of https://github.com/huggingface/diffusers.git
Commit message: update
@@ -40,7 +40,6 @@ class AttentionTesterMixin:

    Expected class attributes to be set by subclasses:
    - model_class: The model class to test
    - base_precision: Tolerance for floating point comparisons (default: 1e-3)
    - uses_custom_attn_processor: Whether model uses custom attention processors (default: False)

    Expected methods to be implemented by subclasses:
@@ -51,8 +50,7 @@ class AttentionTesterMixin:
    Use `pytest -m "not attention"` to skip these tests
    """

    base_precision = 1e-3

    @torch.no_grad()
    def test_fuse_unfuse_qkv_projections(self):
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
@@ -63,14 +61,10 @@ class AttentionTesterMixin:
        if not hasattr(model, "fuse_qkv_projections"):
            pytest.skip("Model does not support QKV projection fusion.")

        # Get output before fusion
        with torch.no_grad():
            output_before_fusion = model(**inputs_dict, return_dict=False)[0]
        output_before_fusion = model(**inputs_dict, return_dict=False)[0]

        # Fuse projections
        model.fuse_qkv_projections()

        # Verify fusion occurred by checking for fused attributes
        has_fused_projections = False
        for module in model.modules():
            if isinstance(module, AttentionModuleMixin):
@@ -80,38 +74,30 @@ class AttentionTesterMixin:
                break

        if has_fused_projections:
            # Get output after fusion
            with torch.no_grad():
                output_after_fusion = model(**inputs_dict, return_dict=False)[0]
            output_after_fusion = model(**inputs_dict, return_dict=False)[0]

            # Verify outputs match
            assert_tensors_close(
                output_before_fusion,
                output_after_fusion,
                atol=self.base_precision,
                atol=1e-3,
                rtol=0,
                msg="Output should not change after fusing projections",
            )

            # Unfuse projections
            model.unfuse_qkv_projections()

            # Verify unfusion occurred
            for module in model.modules():
                if isinstance(module, AttentionModuleMixin):
                    assert not hasattr(module, "to_qkv"), "to_qkv should be removed after unfusing"
                    assert not hasattr(module, "to_kv"), "to_kv should be removed after unfusing"
                    assert not module.fused_projections, "fused_projections flag should be False"

            # Get output after unfusion
            with torch.no_grad():
                output_after_unfusion = model(**inputs_dict, return_dict=False)[0]
            output_after_unfusion = model(**inputs_dict, return_dict=False)[0]

            # Verify outputs still match
            assert_tensors_close(
                output_before_fusion,
                output_after_unfusion,
                atol=self.base_precision,
                atol=1e-3,
                rtol=0,
                msg="Output should match original after unfusing projections",
            )

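For orientation, a subclass wiring up this mixin would look roughly like the sketch below. `DummyAttentionModel` and the tensor shapes are placeholders, not taken from this commit; only `model_class`, `base_precision`, `get_init_dict`, and `get_dummy_inputs` come from the contract described in the docstring above.

```python
import torch


class DummyAttentionModelTests(AttentionTesterMixin):
    model_class = DummyAttentionModel  # placeholder: the diffusers model class under test
    base_precision = 1e-3              # atol used by the mixin's comparisons

    def get_init_dict(self):
        # kwargs passed to model_class(**init_dict)
        return {"in_channels": 4, "attention_head_dim": 8}

    def get_dummy_inputs(self):
        # kwargs passed to model(**inputs_dict)
        return {"hidden_states": torch.randn(1, 4, 16, 16)}
```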
@@ -196,6 +196,7 @@ class CacheTesterMixin:

        model.disable_cache()

    @torch.no_grad()
    def _test_reset_stateful_cache(self):
        """Test that _reset_stateful_cache resets the cache state."""
        init_dict = self.get_init_dict()
@@ -207,11 +208,8 @@ class CacheTesterMixin:

        model.enable_cache(config)

        # Run forward to populate cache state
        with torch.no_grad():
            _ = model(**inputs_dict, return_dict=False)[0]
        _ = model(**inputs_dict, return_dict=False)[0]

        # Reset should not raise any errors
        model._reset_stateful_cache()

        model.disable_cache()
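The recurring change across these hunks is replacing an inline `with torch.no_grad():` block with a `@torch.no_grad()` decorator on the test method. Both forms disable gradient tracking; the decorator simply applies to the whole function body. A minimal standalone illustration (not code from the repository):

```python
import torch

linear = torch.nn.Linear(4, 4)
x = torch.randn(2, 4)


def call_with_context(model, inputs):
    # old style: only the wrapped statements run with gradients disabled
    with torch.no_grad():
        return model(inputs)


@torch.no_grad()
def call_with_decorator(model, inputs):
    # new style: the entire function runs with gradients disabled
    return model(inputs)


assert not call_with_context(linear, x).requires_grad
assert not call_with_decorator(linear, x).requires_grad
```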
@@ -358,6 +356,7 @@ class FirstBlockCacheTesterMixin(FirstBlockCacheConfigMixin, CacheTesterMixin):
            "Cached output should be different from non-cached output due to cache approximation."
        )

    @torch.no_grad()
    def _test_reset_stateful_cache(self):
        """Test that _reset_stateful_cache resets the FBC cache state (requires cache_context)."""
        init_dict = self.get_init_dict()
@@ -368,12 +367,9 @@ class FirstBlockCacheTesterMixin(FirstBlockCacheConfigMixin, CacheTesterMixin):
        config = self._get_cache_config()
        model.enable_cache(config)

        # FBC requires cache_context to be set for inference
        with model.cache_context("fbc_test"):
            with torch.no_grad():
                _ = model(**inputs_dict, return_dict=False)[0]
            _ = model(**inputs_dict, return_dict=False)[0]

        # Reset should not raise any errors
        model._reset_stateful_cache()

        model.disable_cache()
@@ -491,6 +487,7 @@ class FasterCacheTesterMixin(FasterCacheConfigMixin, CacheTesterMixin):
            "Cached output should be different from non-cached output due to cache approximation."
        )

    @torch.no_grad()
    def _test_reset_stateful_cache(self):
        """Test that _reset_stateful_cache resets the FasterCache state."""
        init_dict = self.get_init_dict()
@@ -501,12 +498,9 @@ class FasterCacheTesterMixin(FasterCacheConfigMixin, CacheTesterMixin):
        config = self._get_cache_config()
        model.enable_cache(config)

        # First pass with timestep outside skip range
        self._current_timestep[0] = 1000
        with torch.no_grad():
            _ = model(**inputs_dict, return_dict=False)[0]
        _ = model(**inputs_dict, return_dict=False)[0]

        # Reset should not raise any errors
        model._reset_stateful_cache()

        model.disable_cache()

@@ -145,7 +145,6 @@ class BaseModelTesterConfig:
    - pretrained_model_name_or_path: Hub repository ID for pretrained model (default: None)
    - pretrained_model_kwargs: Additional kwargs for from_pretrained (default: {})
    - output_shape: Expected output shape for output validation tests (default: None)
    - base_precision: Default tolerance for floating point comparisons (default: 1e-3)
    - model_split_percents: Percentages for model parallelism tests (default: [0.5, 0.7])

    Required methods (must be implemented by subclasses):
@@ -259,6 +258,7 @@ class ModelTesterMixin:
        pass
    """

    @torch.no_grad()
    def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=5e-5):
        torch.manual_seed(0)
        model = self.model_class(**self.get_init_dict())
@@ -269,7 +269,6 @@ class ModelTesterMixin:
        new_model = self.model_class.from_pretrained(tmp_path)
        new_model.to(torch_device)

        # check if all parameters shape are the same
        for param_name in model.state_dict().keys():
            param_1 = model.state_dict()[param_name]
            param_2 = new_model.state_dict()[param_name]
@@ -277,12 +276,12 @@ class ModelTesterMixin:
                f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}"
            )

        with torch.no_grad():
            image = model(**self.get_dummy_inputs(), return_dict=False)[0]
            new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
        image = model(**self.get_dummy_inputs(), return_dict=False)[0]
        new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]

        assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")

    @torch.no_grad()
    def test_from_save_pretrained_variant(self, tmp_path, atol=5e-5, rtol=0):
        model = self.model_class(**self.get_init_dict())
        model.to(torch_device)
@@ -291,18 +290,15 @@ class ModelTesterMixin:
        model.save_pretrained(tmp_path, variant="fp16")
        new_model = self.model_class.from_pretrained(tmp_path, variant="fp16")

        # non-variant cannot be loaded
        with pytest.raises(OSError) as exc_info:
            self.model_class.from_pretrained(tmp_path)

        # make sure that error message states what keys are missing
        assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(exc_info.value)

        new_model.to(torch_device)

        with torch.no_grad():
            image = model(**self.get_dummy_inputs(), return_dict=False)[0]
            new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
        image = model(**self.get_dummy_inputs(), return_dict=False)[0]
        new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]

        assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")

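The variant round-trip exercised above corresponds to this user-facing pattern (a sketch; `MyModel` and the directory are placeholders, only `save_pretrained`/`from_pretrained` with `variant="fp16"` come from the code under test):

```python
# Writes the weights under a variant-suffixed filename alongside the config.
model.save_pretrained("checkpoint_dir", variant="fp16")

# Reloading must request the same variant; a plain from_pretrained("checkpoint_dir")
# raises OSError here because only the fp16-variant weight file exists.
reloaded = MyModel.from_pretrained("checkpoint_dir", variant="fp16")
```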
@@ -324,16 +320,15 @@ class ModelTesterMixin:
        new_model = self.model_class.from_pretrained(tmp_path, low_cpu_mem_usage=False, torch_dtype=dtype)
        assert new_model.dtype == dtype

    @torch.no_grad()
    def test_determinism(self, atol=1e-5, rtol=0):
        model = self.model_class(**self.get_init_dict())
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            first = model(**self.get_dummy_inputs(), return_dict=False)[0]
            second = model(**self.get_dummy_inputs(), return_dict=False)[0]
        first = model(**self.get_dummy_inputs(), return_dict=False)[0]
        second = model(**self.get_dummy_inputs(), return_dict=False)[0]

        # Filter out NaN values before comparison
        first_flat = first.flatten()
        second_flat = second.flatten()
        mask = ~(torch.isnan(first_flat) | torch.isnan(second_flat))
@@ -344,24 +339,23 @@ class ModelTesterMixin:
            first_filtered, second_filtered, atol=atol, rtol=rtol, msg="Model outputs are not deterministic"
        )

    @torch.no_grad()
    def test_output(self, expected_output_shape=None):
        model = self.model_class(**self.get_init_dict())
        model.to(torch_device)
        model.eval()

        inputs_dict = self.get_dummy_inputs()
        with torch.no_grad():
            output = model(**inputs_dict, return_dict=False)[0]
        output = model(**inputs_dict, return_dict=False)[0]

        assert output is not None, "Model output is None"
        assert output[0].shape == expected_output_shape or self.output_shape, (
            f"Output shape does not match expected. Expected {expected_output_shape}, got {output.shape}"
        )

    @torch.no_grad()
    def test_outputs_equivalence(self):
        def set_nan_tensor_to_zero(t):
            # Temporary fallback until `aten::_index_put_impl_` is implemented in mps
            # Track progress in https://github.com/pytorch/pytorch/issues/77764
            device = t.device
            if device.type == "mps":
                t = t.to("cpu")
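The helper above is cut off by the hunk boundary; the remainder presumably zeroes the NaN entries on CPU and moves the tensor back, roughly as in this sketch (an assumption about the elided lines, not the verbatim source):

```python
def set_nan_tensor_to_zero(t):
    # MPS lacks the masked index_put kernel, so do the masked write on CPU
    device = t.device
    if device.type == "mps":
        t = t.to("cpu")
    t[t != t] = 0  # NaN is the only value for which x != x
    return t.to(device)
```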
@@ -390,9 +384,8 @@ class ModelTesterMixin:
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            outputs_dict = model(**self.get_dummy_inputs())
            outputs_tuple = model(**self.get_dummy_inputs(), return_dict=False)
        outputs_dict = model(**self.get_dummy_inputs())
        outputs_tuple = model(**self.get_dummy_inputs(), return_dict=False)

        recursive_check(outputs_tuple, outputs_dict)

@@ -465,6 +458,7 @@ class ModelTesterMixin:
        reason="float16 and bfloat16 can only be use for inference with an accelerator",
    )
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
    @torch.no_grad()
    def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
        model = self.model_class(**self.get_init_dict())
        model.to(torch_device)
@@ -479,9 +473,8 @@ class ModelTesterMixin:
        else:
            assert param.data.dtype == dtype

        with torch.no_grad():
            output = model(**self.get_dummy_inputs(), return_dict=False)[0]
            output_loaded = model_loaded(**self.get_dummy_inputs(), return_dict=False)[0]
        output = model(**self.get_dummy_inputs(), return_dict=False)[0]
        output_loaded = model_loaded(**self.get_dummy_inputs(), return_dict=False)[0]

        assert_tensors_close(output, output_loaded, atol=1e-4, rtol=0, msg=f"Loaded model output differs for {dtype}")

@@ -98,6 +98,7 @@ class TorchCompileTesterMixin:
            _ = model(**inputs_dict)
            _ = model(**inputs_dict)

    @torch.no_grad()
    def test_compile_with_group_offloading(self):
        if not self.model_class._supports_group_offloading:
            pytest.skip("Model does not support group offloading.")
@@ -120,10 +121,10 @@ class TorchCompileTesterMixin:
        model.enable_group_offload(**group_offload_kwargs)
        model.compile()

        with torch.no_grad():
            _ = model(**inputs_dict)
            _ = model(**inputs_dict)
        _ = model(**inputs_dict)
        _ = model(**inputs_dict)

    @torch.no_grad()
    def test_compile_on_different_shapes(self):
        if self.different_shapes_for_compilation is None:
            pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
@@ -135,10 +136,11 @@ class TorchCompileTesterMixin:
        model = torch.compile(model, fullgraph=True, dynamic=True)

        for height, width in self.different_shapes_for_compilation:
            with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
            with torch._dynamo.config.patch(error_on_recompile=True):
                inputs_dict = self.get_dummy_inputs(height=height, width=width)
                _ = model(**inputs_dict)

    @torch.no_grad()
    def test_compile_works_with_aot(self, tmp_path):
        from torch._inductor.package import load_package

@@ -155,6 +157,5 @@ class TorchCompileTesterMixin:

        model.forward = loaded_binary

        with torch.no_grad():
            _ = model(**inputs_dict)
            _ = model(**inputs_dict)
        _ = model(**inputs_dict)
        _ = model(**inputs_dict)

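The `error_on_recompile` guard used in `test_compile_on_different_shapes` can be reproduced in isolation. The sketch below (standalone, not from the repository) compiles a small module with `dynamic=True` and turns any recompilation on a new input shape into a hard error:

```python
import torch

model = torch.nn.Linear(8, 8)
compiled = torch.compile(model, fullgraph=True, dynamic=True)

with torch._dynamo.config.patch(error_on_recompile=True):
    for batch_size in (2, 4, 8):
        # With dynamic shapes enabled, changing the batch size should reuse the
        # compiled graph instead of triggering a recompile (which would raise here).
        _ = compiled(torch.randn(batch_size, 8))
```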
@@ -208,7 +208,8 @@ class GroupOffloadTesterMixin:
    """

    @require_group_offload_support
    def test_group_offloading(self, record_stream=False):
    @pytest.mark.parametrize("record_stream", [False, True])
    def test_group_offloading(self, record_stream):
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        torch.manual_seed(0)
@@ -280,8 +281,10 @@ class GroupOffloadTesterMixin:
        )

    @require_group_offload_support
    @pytest.mark.parametrize("record_stream", [False, True])
    @pytest.mark.parametrize("offload_type", ["block_level", "leaf_level"])
    @torch.no_grad()
    def test_group_offloading_with_layerwise_casting(self, record_stream=False, offload_type="block_level"):
    def test_group_offloading_with_layerwise_casting(self, record_stream, offload_type):
        torch.manual_seed(0)
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
@@ -306,9 +309,11 @@ class GroupOffloadTesterMixin:
        _ = model(**inputs_dict)[0]

    @require_group_offload_support
    @pytest.mark.parametrize("record_stream", [False, True])
    @pytest.mark.parametrize("offload_type", ["block_level", "leaf_level"])
    @torch.no_grad()
    @torch.inference_mode()
    def test_group_offloading_with_disk(self, tmp_path, offload_type="block_level", record_stream=False, atol=1e-5):
    def test_group_offloading_with_disk(self, tmp_path, record_stream, offload_type, atol=1e-5):
        def _has_generator_arg(model):
            sig = inspect.signature(model.forward)
            params = sig.parameters

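The stacked `pytest.mark.parametrize` decorators introduced above replace the old keyword defaults with an explicit test matrix; the same pattern in a self-contained example (hypothetical test name, not from the diff):

```python
import pytest


@pytest.mark.parametrize("record_stream", [False, True])
@pytest.mark.parametrize("offload_type", ["block_level", "leaf_level"])
def test_offload_matrix(record_stream, offload_type):
    # pytest collects four test cases, one per (offload_type, record_stream) combination
    assert offload_type in ("block_level", "leaf_level")
    assert isinstance(record_stream, bool)
```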
@@ -27,9 +27,9 @@ from ...testing_utils import (
)


@torch.no_grad()
def _context_parallel_worker(rank, world_size, model_class, init_dict, cp_dict, inputs_dict, result_queue):
    try:
        # Setup distributed environment
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "12355"

@@ -56,8 +56,7 @@ def _context_parallel_worker(rank, world_size, model_class, init_dict, cp_dict,
        cp_config = ContextParallelConfig(**cp_dict)
        model.enable_parallelism(config=cp_config)

        with torch.no_grad():
            output = model(**inputs_on_device, return_dict=False)[0]
        output = model(**inputs_on_device, return_dict=False)[0]

        if rank == 0:
            result_queue.put(("success", output.shape))
@@ -73,8 +72,6 @@ def _context_parallel_worker(rank, world_size, model_class, init_dict, cp_dict,
@is_context_parallel
@require_torch_multi_accelerator
class ContextParallelTesterMixin:
    base_precision = 1e-3

    @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
    def test_context_parallel_inference(self, cp_type):
        if not torch.distributed.is_available():

@@ -168,15 +168,15 @@ class QuantizationTesterMixin:
            f"Memory ratio {ratio:.2f} is less than expected ({expected_memory_reduction}x). unquantized={mem}, quantized={mem_quantized}"
        )

    @torch.no_grad()
    def _test_quantization_inference(self, config_kwargs):
        model_quantized = self._create_quantized_model(config_kwargs)

        with torch.no_grad():
            inputs = self.get_dummy_inputs()
            output = model_quantized(**inputs, return_dict=False)[0]
        inputs = self.get_dummy_inputs()
        output = model_quantized(**inputs, return_dict=False)[0]

            assert output is not None, "Model output is None"
            assert not torch.isnan(output).any(), "Model output contains NaN"
        assert output is not None, "Model output is None"
        assert not torch.isnan(output).any(), "Model output contains NaN"

    def _test_quantization_dtype_assignment(self, config_kwargs):
        model = self._create_quantized_model(config_kwargs)
@@ -196,6 +196,7 @@ class QuantizationTesterMixin:

        model.to(torch_device)

    @torch.no_grad()
    def _test_quantization_lora_inference(self, config_kwargs):
        try:
            from peft import LoraConfig
@@ -217,13 +218,13 @@ class QuantizationTesterMixin:
        )
        model.add_adapter(lora_config)

        with torch.no_grad():
            inputs = self.get_dummy_inputs()
            output = model(**inputs, return_dict=False)[0]
        inputs = self.get_dummy_inputs()
        output = model(**inputs, return_dict=False)[0]

            assert output is not None, "Model output is None with LoRA"
            assert not torch.isnan(output).any(), "Model output contains NaN with LoRA"
        assert output is not None, "Model output is None with LoRA"
        assert not torch.isnan(output).any(), "Model output contains NaN with LoRA"

    @torch.no_grad()
    def _test_quantization_serialization(self, config_kwargs, tmp_path):
        model = self._create_quantized_model(config_kwargs)

@@ -231,10 +232,9 @@ class QuantizationTesterMixin:

        model_loaded = self.model_class.from_pretrained(str(tmp_path))

        with torch.no_grad():
            inputs = self.get_dummy_inputs()
            output = model_loaded(**inputs, return_dict=False)[0]
            assert not torch.isnan(output).any(), "Loaded model output contains NaN"
        inputs = self.get_dummy_inputs()
        output = model_loaded(**inputs, return_dict=False)[0]
        assert not torch.isnan(output).any(), "Loaded model output contains NaN"

    def _test_quantized_layers(self, config_kwargs):
        model_fp = self._load_unquantized_model()
@@ -317,6 +317,7 @@ class QuantizationTesterMixin:
            f"Model with exclusions should be larger. With exclusion: {mem_with_exclusion}, fully quantized: {mem_fully_quantized}"
        )

    @torch.no_grad()
    def _test_quantization_device_map(self, config_kwargs):
        """
        Test that quantized models work correctly with device_map="auto".
@@ -326,17 +327,15 @@ class QuantizationTesterMixin:
        """
        model = self._create_quantized_model(config_kwargs, device_map="auto")

        # Verify device map is set
        assert hasattr(model, "hf_device_map"), "Model should have hf_device_map attribute"
        assert model.hf_device_map is not None, "hf_device_map should not be None"

        # Verify inference works
        with torch.no_grad():
            inputs = self.get_dummy_inputs()
            output = model(**inputs, return_dict=False)[0]
            assert output is not None, "Model output is None"
            assert not torch.isnan(output).any(), "Model output contains NaN"
        inputs = self.get_dummy_inputs()
        output = model(**inputs, return_dict=False)[0]
        assert output is not None, "Model output is None"
        assert not torch.isnan(output).any(), "Model output contains NaN"

    @torch.no_grad()
    def _test_dequantize(self, config_kwargs):
        """
        Test that dequantize() converts quantized model back to standard linear layers.
@@ -346,24 +345,19 @@ class QuantizationTesterMixin:
        """
        model = self._create_quantized_model(config_kwargs)

        # Verify model has dequantize method
        if not hasattr(model, "dequantize"):
            pytest.skip("Model does not have dequantize method")

        # Dequantize the model
        model.dequantize()

        # Verify no modules are quantized after dequantization
        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Linear):
                assert not self._is_module_quantized(module), f"Module {name} is still quantized after dequantize()"

        # Verify inference still works after dequantization
        with torch.no_grad():
            inputs = self.get_dummy_inputs()
            output = model(**inputs, return_dict=False)[0]
            assert output is not None, "Model output is None after dequantization"
            assert not torch.isnan(output).any(), "Model output contains NaN after dequantization"
        inputs = self.get_dummy_inputs()
        output = model(**inputs, return_dict=False)[0]
        assert output is not None, "Model output is None after dequantization"
        assert not torch.isnan(output).any(), "Model output contains NaN after dequantization"

    def _test_quantization_training(self, config_kwargs):
        """
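`_is_module_quantized` is a subclass-provided helper not shown in this diff; for a bitsandbytes-backed tester it could plausibly look like the following (an assumption, not the repository's implementation):

```python
def _is_module_quantized(self, module):
    # Hypothetical backend-specific check: bitsandbytes replaces nn.Linear with
    # its own 4-bit / 8-bit linear modules, so an isinstance test suffices here.
    try:
        import bitsandbytes as bnb
    except ImportError:
        return False
    return isinstance(module, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt))
```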
@@ -566,6 +560,7 @@ class BitsAndBytesTesterMixin(BitsAndBytesConfigMixin, QuantizationTesterMixin):
            torch.bfloat16,
        ], f"Unexpected dtype: {model.config['_pre_quantization_dtype']}"

    @torch.no_grad()
    def test_bnb_keep_modules_in_fp32(self):
        if not hasattr(self.model_class, "_keep_in_fp32_modules"):
            pytest.skip(f"{self.model_class.__name__} does not have _keep_in_fp32_modules")
@@ -589,9 +584,8 @@ class BitsAndBytesTesterMixin(BitsAndBytesConfigMixin, QuantizationTesterMixin):
                        f"Module {name} should be uint8 but is {module.weight.dtype}"
                    )

            with torch.no_grad():
                inputs = self.get_dummy_inputs()
                _ = model(**inputs)
            inputs = self.get_dummy_inputs()
            _ = model(**inputs)
        finally:
            if original_fp32_modules is not None:
                self.model_class._keep_in_fp32_modules = original_fp32_modules
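For reference, the quantized models these BitsAndBytes tests exercise are typically constructed along these lines (a sketch; the model class and repository id are placeholders, while `BitsAndBytesConfig` and the `quantization_config` argument are the real diffusers API):

```python
import torch
from diffusers import BitsAndBytesConfig

quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)

# "MyTransformerModel" and the repo id stand in for whatever model_class the mixin is bound to.
model = MyTransformerModel.from_pretrained(
    "org/some-checkpoint",
    quantization_config=quant_config,
)
```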
@@ -1097,6 +1091,7 @@ class QuantizationCompileTesterMixin:
        backend_empty_cache(torch_device)
        torch.compiler.reset()

    @torch.no_grad()
    def _test_torch_compile(self, config_kwargs):
        """
        Test that torch.compile works correctly with a quantized model.
@@ -1108,16 +1103,15 @@ class QuantizationCompileTesterMixin:
        model.to(torch_device)
        model.eval()

        # Compile the model with fullgraph=True to ensure no graph breaks
        model = torch.compile(model, fullgraph=True)

        # Run inference with error_on_recompile to detect recompilation issues
        with torch.no_grad(), torch._dynamo.config.patch(error_on_recompile=True):
        with torch._dynamo.config.patch(error_on_recompile=True):
            inputs = self.get_dummy_inputs()
            output = model(**inputs, return_dict=False)[0]
            assert output is not None, "Model output is None"
            assert not torch.isnan(output).any(), "Model output contains NaN"

    @torch.no_grad()
    def _test_torch_compile_with_group_offload(self, config_kwargs, use_stream=False):
        """
        Test that torch.compile works correctly with a quantized model and group offloading.
@@ -1143,11 +1137,10 @@ class QuantizationCompileTesterMixin:
        model.enable_group_offload(**group_offload_kwargs)
        model = torch.compile(model)

        with torch.no_grad():
            inputs = self.get_dummy_inputs()
            output = model(**inputs, return_dict=False)[0]
            assert output is not None, "Model output is None"
            assert not torch.isnan(output).any(), "Model output contains NaN"
        inputs = self.get_dummy_inputs()
        output = model(**inputs, return_dict=False)[0]
        assert output is not None, "Model output is None"
        assert not torch.isnan(output).any(), "Model output contains NaN"


@is_bitsandbytes
