mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
DN6 committed on 2025-12-18 13:18:54 +05:30
parent e82001e40d, commit c70de2bc37
2 changed files with 37 additions and 37 deletions
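Both files receive the same mechanical change: multi-line assert statements are rewritten so the condition stays on one line and the failure message is wrapped in parentheses on the following line, matching the layout that newer Python formatters (such as ruff format) emit. Behaviour is unchanged. Below is a minimal sketch of the before/after pattern, using dummy tensors in place of the models' cached and non-cached outputs:

import torch

# Stand-in tensors; the real tests compare model outputs with and without caching enabled.
output_without_cache = torch.randn(2, 4)
output_with_cache = output_without_cache + 1e-3

# Old layout: the condition is broken across lines and the message trails the closing paren.
assert not torch.allclose(
    output_without_cache, output_with_cache, atol=1e-5
), "Cached output should be different from non-cached output due to cache approximation."

# New layout: the condition fits on one line and the message hangs inside parentheses.
assert not torch.allclose(output_without_cache, output_with_cache, atol=1e-5), (
    "Cached output should be different from non-cached output due to cache approximation."
)

Both forms are semantically identical; only the wrapping of the condition and message differs.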

View File

@@ -24,7 +24,7 @@ from diffusers.hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK
from diffusers.hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
from diffusers.models.cache_utils import CacheMixin
-from ...testing_utils import assert_tensors_close, backend_empty_cache, is_cache, torch_device
+from ...testing_utils import backend_empty_cache, is_cache, torch_device
def require_cache_mixin(func):
@@ -177,9 +177,9 @@ class CacheTesterMixin:
        output_without_cache = model(**inputs_dict_step2, return_dict=False)[0]
        # Cached output should be different from non-cached output (due to approximation)
-        assert not torch.allclose(
-            output_without_cache, output_with_cache, atol=1e-5
-        ), "Cached output should be different from non-cached output due to cache approximation."
+        assert not torch.allclose(output_without_cache, output_with_cache, atol=1e-5), (
+            "Cached output should be different from non-cached output due to cache approximation."
+        )

    def _test_cache_context_manager(self):
        """Test the cache_context context manager."""
@@ -354,9 +354,9 @@ class FirstBlockCacheTesterMixin(FirstBlockCacheConfigMixin, CacheTesterMixin):
        output_without_cache = model(**inputs_dict_step2, return_dict=False)[0]
        # Cached output should be different from non-cached output (due to approximation)
-        assert not torch.allclose(
-            output_without_cache, output_with_cache, atol=1e-5
-        ), "Cached output should be different from non-cached output due to cache approximation."
+        assert not torch.allclose(output_without_cache, output_with_cache, atol=1e-5), (
+            "Cached output should be different from non-cached output due to cache approximation."
+        )

    def _test_reset_stateful_cache(self):
        """Test that _reset_stateful_cache resets the FBC cache state (requires cache_context)."""
@@ -487,9 +487,9 @@ class FasterCacheTesterMixin(FasterCacheConfigMixin, CacheTesterMixin):
        output_without_cache = model(**inputs_dict_step2, return_dict=False)[0]
        # Cached output should be different from non-cached output (due to approximation)
-        assert not torch.allclose(
-            output_without_cache, output_with_cache, atol=1e-5
-        ), "Cached output should be different from non-cached output due to cache approximation."
+        assert not torch.allclose(output_without_cache, output_with_cache, atol=1e-5), (
+            "Cached output should be different from non-cached output due to cache approximation."
+        )

    def _test_reset_stateful_cache(self):
        """Test that _reset_stateful_cache resets the FasterCache state."""

View File

@@ -152,9 +152,9 @@ class QuantizationTesterMixin:
        model_quantized = self._create_quantized_model(config_kwargs)
        num_params_quantized = model_quantized.num_parameters()
-        assert (
-            num_params == num_params_quantized
-        ), f"Parameter count mismatch: unquantized={num_params}, quantized={num_params_quantized}"
+        assert num_params == num_params_quantized, (
+            f"Parameter count mismatch: unquantized={num_params}, quantized={num_params_quantized}"
+        )

    def _test_quantization_memory_footprint(self, config_kwargs, expected_memory_reduction=1.2):
        model = self._load_unquantized_model()
@@ -164,9 +164,9 @@ class QuantizationTesterMixin:
        mem_quantized = model_quantized.get_memory_footprint()
        ratio = mem / mem_quantized
-        assert (
-            ratio >= expected_memory_reduction
-        ), f"Memory ratio {ratio:.2f} is less than expected ({expected_memory_reduction}x). unquantized={mem}, quantized={mem_quantized}"
+        assert ratio >= expected_memory_reduction, (
+            f"Memory ratio {ratio:.2f} is less than expected ({expected_memory_reduction}x). unquantized={mem}, quantized={mem_quantized}"
+        )

    def _test_quantization_inference(self, config_kwargs):
        model_quantized = self._create_quantized_model(config_kwargs)
@@ -260,12 +260,12 @@ class QuantizationTesterMixin:
                self._verify_if_layer_quantized(name, module, config_kwargs)
                num_quantized_layers += 1

-        assert (
-            num_quantized_layers > 0
-        ), f"No quantized layers found in model (expected {expected_quantized_layers} linear layers, {num_fp32_modules} kept in FP32)"
-        assert (
-            num_quantized_layers == expected_quantized_layers
-        ), f"Quantized layer count mismatch: expected {expected_quantized_layers}, got {num_quantized_layers} (total linear layers: {num_linear_layers}, FP32 modules: {num_fp32_modules})"
+        assert num_quantized_layers > 0, (
+            f"No quantized layers found in model (expected {expected_quantized_layers} linear layers, {num_fp32_modules} kept in FP32)"
+        )
+        assert num_quantized_layers == expected_quantized_layers, (
+            f"Quantized layer count mismatch: expected {expected_quantized_layers}, got {num_quantized_layers} (total linear layers: {num_linear_layers}, FP32 modules: {num_fp32_modules})"
+        )

    def _test_quantization_modules_to_not_convert(self, config_kwargs, modules_to_not_convert):
        """
@@ -289,9 +289,9 @@ class QuantizationTesterMixin:
                if any(excluded in name for excluded in modules_to_not_convert):
                    found_excluded = True
                    # This module should NOT be quantized
-                    assert not self._is_module_quantized(
-                        module
-                    ), f"Module {name} should not be quantized but was found to be quantized"
+                    assert not self._is_module_quantized(module), (
+                        f"Module {name} should not be quantized but was found to be quantized"
+                    )

        assert found_excluded, f"No linear layers found in excluded modules: {modules_to_not_convert}"
@@ -313,9 +313,9 @@ class QuantizationTesterMixin:
        mem_with_exclusion = model_with_exclusion.get_memory_footprint()
        mem_fully_quantized = model_fully_quantized.get_memory_footprint()
-        assert (
-            mem_with_exclusion > mem_fully_quantized
-        ), f"Model with exclusions should be larger. With exclusion: {mem_with_exclusion}, fully quantized: {mem_fully_quantized}"
+        assert mem_with_exclusion > mem_fully_quantized, (
+            f"Model with exclusions should be larger. With exclusion: {mem_with_exclusion}, fully quantized: {mem_fully_quantized}"
+        )

    def _test_quantization_device_map(self, config_kwargs):
        """
@@ -465,9 +465,9 @@ class BitsAndBytesConfigMixin:
    def _verify_if_layer_quantized(self, name, module, config_kwargs):
        expected_weight_class = bnb.nn.Params4bit if config_kwargs.get("load_in_4bit") else bnb.nn.Int8Params
-        assert (
-            module.weight.__class__ == expected_weight_class
-        ), f"Layer {name} has weight type {module.weight.__class__}, expected {expected_weight_class}"
+        assert module.weight.__class__ == expected_weight_class, (
+            f"Layer {name} has weight type {module.weight.__class__}, expected {expected_weight_class}"
+        )

    @is_bitsandbytes
@@ -581,13 +581,13 @@ class BitsAndBytesTesterMixin(BitsAndBytesConfigMixin, QuantizationTesterMixin):
        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Linear):
                if any(fp32_name in name for fp32_name in model._keep_in_fp32_modules):
-                    assert (
-                        module.weight.dtype == torch.float32
-                    ), f"Module {name} should be FP32 but is {module.weight.dtype}"
+                    assert module.weight.dtype == torch.float32, (
+                        f"Module {name} should be FP32 but is {module.weight.dtype}"
+                    )
                else:
-                    assert (
-                        module.weight.dtype == torch.uint8
-                    ), f"Module {name} should be uint8 but is {module.weight.dtype}"
+                    assert module.weight.dtype == torch.uint8, (
+                        f"Module {name} should be uint8 but is {module.weight.dtype}"
+                    )

        with torch.no_grad():
            inputs = self.get_dummy_inputs()