Mirror of https://github.com/huggingface/diffusers.git
Commit: update
@@ -24,7 +24,7 @@ from diffusers.hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK
 from diffusers.hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
 from diffusers.models.cache_utils import CacheMixin
 
-from ...testing_utils import assert_tensors_close, backend_empty_cache, is_cache, torch_device
+from ...testing_utils import backend_empty_cache, is_cache, torch_device
 
 
 def require_cache_mixin(func):
@@ -177,9 +177,9 @@ class CacheTesterMixin:
         output_without_cache = model(**inputs_dict_step2, return_dict=False)[0]
 
         # Cached output should be different from non-cached output (due to approximation)
-        assert not torch.allclose(
-            output_without_cache, output_with_cache, atol=1e-5
-        ), "Cached output should be different from non-cached output due to cache approximation."
+        assert not torch.allclose(output_without_cache, output_with_cache, atol=1e-5), (
+            "Cached output should be different from non-cached output due to cache approximation."
+        )
 
     def _test_cache_context_manager(self):
         """Test the cache_context context manager."""
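Note on the hunk above: both assert layouts raise the same AssertionError with the same message; the rewrite only moves the line break so the boolean condition stays on one line and the failure message is wrapped in parentheses. A minimal, self-contained illustration of the two forms (the tensors and tolerance below are made up for demonstration and are not taken from the test file):

# Old layout: the condition is split across lines, the message trails the closing paren.
# New layout: the condition stays on one line, only the message is parenthesized.
import torch

a = torch.zeros(3)
b = torch.full((3,), 1e-3)

assert not torch.allclose(
    a, b, atol=1e-5
), "Cached output should be different from non-cached output due to cache approximation."

assert not torch.allclose(a, b, atol=1e-5), (
    "Cached output should be different from non-cached output due to cache approximation."
)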
@@ -354,9 +354,9 @@ class FirstBlockCacheTesterMixin(FirstBlockCacheConfigMixin, CacheTesterMixin):
         output_without_cache = model(**inputs_dict_step2, return_dict=False)[0]
 
         # Cached output should be different from non-cached output (due to approximation)
-        assert not torch.allclose(
-            output_without_cache, output_with_cache, atol=1e-5
-        ), "Cached output should be different from non-cached output due to cache approximation."
+        assert not torch.allclose(output_without_cache, output_with_cache, atol=1e-5), (
+            "Cached output should be different from non-cached output due to cache approximation."
+        )
 
     def _test_reset_stateful_cache(self):
         """Test that _reset_stateful_cache resets the FBC cache state (requires cache_context)."""
@@ -487,9 +487,9 @@ class FasterCacheTesterMixin(FasterCacheConfigMixin, CacheTesterMixin):
         output_without_cache = model(**inputs_dict_step2, return_dict=False)[0]
 
         # Cached output should be different from non-cached output (due to approximation)
-        assert not torch.allclose(
-            output_without_cache, output_with_cache, atol=1e-5
-        ), "Cached output should be different from non-cached output due to cache approximation."
+        assert not torch.allclose(output_without_cache, output_with_cache, atol=1e-5), (
+            "Cached output should be different from non-cached output due to cache approximation."
+        )
 
     def _test_reset_stateful_cache(self):
         """Test that _reset_stateful_cache resets the FasterCache state."""
@@ -152,9 +152,9 @@ class QuantizationTesterMixin:
         model_quantized = self._create_quantized_model(config_kwargs)
         num_params_quantized = model_quantized.num_parameters()
 
-        assert (
-            num_params == num_params_quantized
-        ), f"Parameter count mismatch: unquantized={num_params}, quantized={num_params_quantized}"
+        assert num_params == num_params_quantized, (
+            f"Parameter count mismatch: unquantized={num_params}, quantized={num_params_quantized}"
+        )
 
     def _test_quantization_memory_footprint(self, config_kwargs, expected_memory_reduction=1.2):
         model = self._load_unquantized_model()
@@ -164,9 +164,9 @@ class QuantizationTesterMixin:
         mem_quantized = model_quantized.get_memory_footprint()
 
         ratio = mem / mem_quantized
-        assert (
-            ratio >= expected_memory_reduction
-        ), f"Memory ratio {ratio:.2f} is less than expected ({expected_memory_reduction}x). unquantized={mem}, quantized={mem_quantized}"
+        assert ratio >= expected_memory_reduction, (
+            f"Memory ratio {ratio:.2f} is less than expected ({expected_memory_reduction}x). unquantized={mem}, quantized={mem_quantized}"
+        )
 
     def _test_quantization_inference(self, config_kwargs):
         model_quantized = self._create_quantized_model(config_kwargs)
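For orientation on the memory check above: the assertion is a plain quotient of byte footprints, and the default bar of 1.2x sits well below the idealized reduction quantization can give. A rough back-of-the-envelope sketch (parameter count and byte sizes are idealized for illustration, not measured from any diffusers model; real checkpoints keep norms, embeddings, and quantization metadata in higher precision, which is why a conservative default makes sense):

# Idealized footprint arithmetic for a hypothetical 100M-parameter model.
num_params = 100_000_000
mem_fp16 = num_params * 2      # fp16: 2 bytes per parameter
mem_int8 = num_params * 1      # 8-bit quantization: ~1 byte per parameter
mem_4bit = num_params * 0.5    # 4-bit quantization: ~0.5 bytes per parameter

print(mem_fp16 / mem_int8)     # ~2.0x reduction in the ideal case
print(mem_fp16 / mem_4bit)     # ~4.0x reduction in the ideal case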
@@ -260,12 +260,12 @@ class QuantizationTesterMixin:
                 self._verify_if_layer_quantized(name, module, config_kwargs)
                 num_quantized_layers += 1
 
-        assert (
-            num_quantized_layers > 0
-        ), f"No quantized layers found in model (expected {expected_quantized_layers} linear layers, {num_fp32_modules} kept in FP32)"
-        assert (
-            num_quantized_layers == expected_quantized_layers
-        ), f"Quantized layer count mismatch: expected {expected_quantized_layers}, got {num_quantized_layers} (total linear layers: {num_linear_layers}, FP32 modules: {num_fp32_modules})"
+        assert num_quantized_layers > 0, (
+            f"No quantized layers found in model (expected {expected_quantized_layers} linear layers, {num_fp32_modules} kept in FP32)"
+        )
+        assert num_quantized_layers == expected_quantized_layers, (
+            f"Quantized layer count mismatch: expected {expected_quantized_layers}, got {num_quantized_layers} (total linear layers: {num_linear_layers}, FP32 modules: {num_fp32_modules})"
+        )
 
     def _test_quantization_modules_to_not_convert(self, config_kwargs, modules_to_not_convert):
         """
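As a rough standalone sketch of what the layer-counting check above exercises, one way to count quantized linears under a bitsandbytes backend is shown below. The helper is hypothetical (not part of the test file) and assumes bitsandbytes quantization, where converted layers remain torch.nn.Linear subclasses whose weights have been replaced by a bitsandbytes parameter class:

# Hypothetical counting helper: a layer counts as quantized when its weight parameter
# has been swapped for one of the bitsandbytes parameter classes.
import bitsandbytes as bnb
import torch

def count_quantized_linears(model: torch.nn.Module) -> int:
    count = 0
    for module in model.modules():
        if isinstance(module, torch.nn.Linear) and isinstance(module.weight, (bnb.nn.Params4bit, bnb.nn.Int8Params)):
            count += 1
    return count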
@@ -289,9 +289,9 @@ class QuantizationTesterMixin:
                 if any(excluded in name for excluded in modules_to_not_convert):
                     found_excluded = True
                     # This module should NOT be quantized
-                    assert not self._is_module_quantized(
-                        module
-                    ), f"Module {name} should not be quantized but was found to be quantized"
+                    assert not self._is_module_quantized(module), (
+                        f"Module {name} should not be quantized but was found to be quantized"
+                    )
 
         assert found_excluded, f"No linear layers found in excluded modules: {modules_to_not_convert}"
 
@@ -313,9 +313,9 @@ class QuantizationTesterMixin:
         mem_with_exclusion = model_with_exclusion.get_memory_footprint()
         mem_fully_quantized = model_fully_quantized.get_memory_footprint()
 
-        assert (
-            mem_with_exclusion > mem_fully_quantized
-        ), f"Model with exclusions should be larger. With exclusion: {mem_with_exclusion}, fully quantized: {mem_fully_quantized}"
+        assert mem_with_exclusion > mem_fully_quantized, (
+            f"Model with exclusions should be larger. With exclusion: {mem_with_exclusion}, fully quantized: {mem_fully_quantized}"
+        )
 
     def _test_quantization_device_map(self, config_kwargs):
         """
@@ -465,9 +465,9 @@ class BitsAndBytesConfigMixin:
 
     def _verify_if_layer_quantized(self, name, module, config_kwargs):
        expected_weight_class = bnb.nn.Params4bit if config_kwargs.get("load_in_4bit") else bnb.nn.Int8Params
-        assert (
-            module.weight.__class__ == expected_weight_class
-        ), f"Layer {name} has weight type {module.weight.__class__}, expected {expected_weight_class}"
+        assert module.weight.__class__ == expected_weight_class, (
+            f"Layer {name} has weight type {module.weight.__class__}, expected {expected_weight_class}"
+        )
 
 
 @is_bitsandbytes
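The check above keys off the parameter class bitsandbytes installs on converted layers: bnb.nn.Params4bit for 4-bit linears (bnb.nn.Linear4bit) and bnb.nn.Int8Params for 8-bit linears (bnb.nn.Linear8bitLt). A minimal sketch of that mapping as a standalone helper (the helper name and the plain-dict config kwargs are hypothetical, not part of the test file):

# Hypothetical helper mirroring the check above: map quantization kwargs to the weight
# class bitsandbytes is expected to install on converted Linear layers.
import bitsandbytes as bnb

def expected_bnb_weight_class(config_kwargs: dict) -> type:
    # 4-bit layers carry bnb.nn.Params4bit weights; 8-bit layers carry bnb.nn.Int8Params weights.
    return bnb.nn.Params4bit if config_kwargs.get("load_in_4bit") else bnb.nn.Int8Params

assert expected_bnb_weight_class({"load_in_4bit": True}) is bnb.nn.Params4bit
assert expected_bnb_weight_class({"load_in_8bit": True}) is bnb.nn.Int8Params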
@@ -581,13 +581,13 @@ class BitsAndBytesTesterMixin(BitsAndBytesConfigMixin, QuantizationTesterMixin):
         for name, module in model.named_modules():
             if isinstance(module, torch.nn.Linear):
                 if any(fp32_name in name for fp32_name in model._keep_in_fp32_modules):
-                    assert (
-                        module.weight.dtype == torch.float32
-                    ), f"Module {name} should be FP32 but is {module.weight.dtype}"
+                    assert module.weight.dtype == torch.float32, (
+                        f"Module {name} should be FP32 but is {module.weight.dtype}"
+                    )
                 else:
-                    assert (
-                        module.weight.dtype == torch.uint8
-                    ), f"Module {name} should be uint8 but is {module.weight.dtype}"
+                    assert module.weight.dtype == torch.uint8, (
+                        f"Module {name} should be uint8 but is {module.weight.dtype}"
+                    )
 
         with torch.no_grad():
             inputs = self.get_dummy_inputs()