From 7b3ef42a013d8c3395c4b56103441189964b7643 Mon Sep 17 00:00:00 2001 From: DN6 Date: Fri, 26 Dec 2025 12:45:30 +0530 Subject: [PATCH] Add quantization compile and cache testers to generate_model_tests --- utils/generate_model_tests.py | 60 ++++++++++++++++++++++++++++++++++- 1 file changed, 59 insertions(+), 1 deletion(-) diff --git a/utils/generate_model_tests.py b/utils/generate_model_tests.py index f3860f4b9a..21116ba42f 100644 --- a/utils/generate_model_tests.py +++ b/utils/generate_model_tests.py @@ -53,11 +53,23 @@ ATTENTION_INDICATORS = { } OPTIONAL_TESTERS = [ + # Quantization testers ("BitsAndBytesTesterMixin", "bnb"), ("QuantoTesterMixin", "quanto"), ("TorchAoTesterMixin", "torchao"), ("GGUFTesterMixin", "gguf"), ("ModelOptTesterMixin", "modelopt"), + # Quantization compile testers + ("BitsAndBytesCompileTesterMixin", "bnb_compile"), + ("QuantoCompileTesterMixin", "quanto_compile"), + ("TorchAoCompileTesterMixin", "torchao_compile"), + ("GGUFCompileTesterMixin", "gguf_compile"), + ("ModelOptCompileTesterMixin", "modelopt_compile"), + # Cache testers + ("PyramidAttentionBroadcastTesterMixin", "pab_cache"), + ("FirstBlockCacheTesterMixin", "fbc_cache"), + ("FasterCacheTesterMixin", "faster_cache"), + # Other testers ("SingleFileTesterMixin", "single_file"), ("IPAdapterTesterMixin", "ip_adapter"), ] @@ -339,6 +351,35 @@ def generate_test_class(model_name: str, config_class: str, tester: str) -> str: " return {}", ] ) + elif tester in [ + "BitsAndBytesCompileTesterMixin", + "QuantoCompileTesterMixin", + "TorchAoCompileTesterMixin", + "ModelOptCompileTesterMixin", + ]: + lines.extend( + [ + " def get_dummy_inputs(self) -> dict[str, torch.Tensor]:", + " # TODO: Override with larger inputs for quantization compile tests", + " return {}", + ] + ) + elif tester == "GGUFCompileTesterMixin": + lines.extend( + [ + ' gguf_filename = "" # TODO: Set GGUF filename', + "", + " def get_dummy_inputs(self) -> dict[str, torch.Tensor]:", + " # TODO: Override with larger inputs for quantization compile tests", + " return {}", + ] + )
+ elif tester in [ + "PyramidAttentionBroadcastTesterMixin", + "FirstBlockCacheTesterMixin", + "FasterCacheTesterMixin", + ]: + lines.append(" pass") elif tester == "LoraHotSwappingForModelTesterMixin": lines.extend( [ @@ -448,7 +489,24 @@ def main(): type=str, nargs="*", default=[], - choices=["compile", "bnb", "quanto", "torchao", "gguf", "modelopt", "single_file", "ip_adapter", "all"], + choices=[ + "bnb", + "quanto", + "torchao", + "gguf", + "modelopt", + "bnb_compile", + "quanto_compile", + "torchao_compile", + "gguf_compile", + "modelopt_compile", + "pab_cache", + "fbc_cache", + "faster_cache", + "single_file", + "ip_adapter", + "all", + ], help="Optional testers to include", ) parser.add_argument(