1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
This commit is contained in:
DN6
2025-12-26 12:45:30 +05:30
parent c70de2bc37
commit 7b3ef42a01

View File

@@ -53,11 +53,23 @@ ATTENTION_INDICATORS = {
}
OPTIONAL_TESTERS = [
# Quantization testers
("BitsAndBytesTesterMixin", "bnb"),
("QuantoTesterMixin", "quanto"),
("TorchAoTesterMixin", "torchao"),
("GGUFTesterMixin", "gguf"),
("ModelOptTesterMixin", "modelopt"),
# Quantization compile testers
("BitsAndBytesCompileTesterMixin", "bnb_compile"),
("QuantoCompileTesterMixin", "quanto_compile"),
("TorchAoCompileTesterMixin", "torchao_compile"),
("GGUFCompileTesterMixin", "gguf_compile"),
("ModelOptCompileTesterMixin", "modelopt_compile"),
# Cache testers
("PyramidAttentionBroadcastTesterMixin", "pab_cache"),
("FirstBlockCacheTesterMixin", "fbc_cache"),
("FasterCacheTesterMixin", "faster_cache"),
# Other testers
("SingleFileTesterMixin", "single_file"),
("IPAdapterTesterMixin", "ip_adapter"),
]
@@ -339,6 +351,35 @@ def generate_test_class(model_name: str, config_class: str, tester: str) -> str:
" return {}",
]
)
elif tester in [
"BitsAndBytesCompileTesterMixin",
"QuantoCompileTesterMixin",
"TorchAoCompileTesterMixin",
"ModelOptCompileTesterMixin",
]:
lines.extend(
[
" def get_dummy_inputs(self) -> dict[str, torch.Tensor]:",
" # TODO: Override with larger inputs for quantization compile tests",
" return {}",
]
)
elif tester == "GGUFCompileTesterMixin":
lines.extend(
[
' gguf_filename = "" # TODO: Set GGUF filename',
"",
" def get_dummy_inputs(self) -> dict[str, torch.Tensor]:",
" # TODO: Override with larger inputs for quantization compile tests",
" return {}",
]
)
elif tester in [
"PyramidAttentionBroadcastTesterMixin",
"FirstBlockCacheTesterMixin",
"FasterCacheTesterMixin",
]:
lines.append(" pass")
elif tester == "LoraHotSwappingForModelTesterMixin":
lines.extend(
[
@@ -448,7 +489,24 @@ def main():
type=str,
nargs="*",
default=[],
choices=["compile", "bnb", "quanto", "torchao", "gguf", "modelopt", "single_file", "ip_adapter", "all"],
choices=[
"bnb",
"quanto",
"torchao",
"gguf",
"modelopt",
"bnb_compile",
"quanto_compile",
"torchao_compile",
"gguf_compile",
"modelopt_compile",
"pab_cache",
"fbc_cache",
"faster_cache",
"single_file",
"ip_adapter",
"all",
],
help="Optional testers to include",
)
parser.add_argument(