mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
update
This commit is contained in:
@@ -53,11 +53,23 @@ ATTENTION_INDICATORS = {
|
||||
}
|
||||
|
||||
OPTIONAL_TESTERS = [
|
||||
# Quantization testers
|
||||
("BitsAndBytesTesterMixin", "bnb"),
|
||||
("QuantoTesterMixin", "quanto"),
|
||||
("TorchAoTesterMixin", "torchao"),
|
||||
("GGUFTesterMixin", "gguf"),
|
||||
("ModelOptTesterMixin", "modelopt"),
|
||||
# Quantization compile testers
|
||||
("BitsAndBytesCompileTesterMixin", "bnb_compile"),
|
||||
("QuantoCompileTesterMixin", "quanto_compile"),
|
||||
("TorchAoCompileTesterMixin", "torchao_compile"),
|
||||
("GGUFCompileTesterMixin", "gguf_compile"),
|
||||
("ModelOptCompileTesterMixin", "modelopt_compile"),
|
||||
# Cache testers
|
||||
("PyramidAttentionBroadcastTesterMixin", "pab_cache"),
|
||||
("FirstBlockCacheTesterMixin", "fbc_cache"),
|
||||
("FasterCacheTesterMixin", "faster_cache"),
|
||||
# Other testers
|
||||
("SingleFileTesterMixin", "single_file"),
|
||||
("IPAdapterTesterMixin", "ip_adapter"),
|
||||
]
|
||||
@@ -339,6 +351,35 @@ def generate_test_class(model_name: str, config_class: str, tester: str) -> str:
|
||||
" return {}",
|
||||
]
|
||||
)
|
||||
elif tester in [
|
||||
"BitsAndBytesCompileTesterMixin",
|
||||
"QuantoCompileTesterMixin",
|
||||
"TorchAoCompileTesterMixin",
|
||||
"ModelOptCompileTesterMixin",
|
||||
]:
|
||||
lines.extend(
|
||||
[
|
||||
" def get_dummy_inputs(self) -> dict[str, torch.Tensor]:",
|
||||
" # TODO: Override with larger inputs for quantization compile tests",
|
||||
" return {}",
|
||||
]
|
||||
)
|
||||
elif tester == "GGUFCompileTesterMixin":
|
||||
lines.extend(
|
||||
[
|
||||
' gguf_filename = "" # TODO: Set GGUF filename',
|
||||
"",
|
||||
" def get_dummy_inputs(self) -> dict[str, torch.Tensor]:",
|
||||
" # TODO: Override with larger inputs for quantization compile tests",
|
||||
" return {}",
|
||||
]
|
||||
)
|
||||
elif tester in [
|
||||
"PyramidAttentionBroadcastTesterMixin",
|
||||
"FirstBlockCacheTesterMixin",
|
||||
"FasterCacheTesterMixin",
|
||||
]:
|
||||
lines.append(" pass")
|
||||
elif tester == "LoraHotSwappingForModelTesterMixin":
|
||||
lines.extend(
|
||||
[
|
||||
@@ -448,7 +489,24 @@ def main():
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=[],
|
||||
choices=["compile", "bnb", "quanto", "torchao", "gguf", "modelopt", "single_file", "ip_adapter", "all"],
|
||||
choices=[
|
||||
"bnb",
|
||||
"quanto",
|
||||
"torchao",
|
||||
"gguf",
|
||||
"modelopt",
|
||||
"bnb_compile",
|
||||
"quanto_compile",
|
||||
"torchao_compile",
|
||||
"gguf_compile",
|
||||
"modelopt_compile",
|
||||
"pab_cache",
|
||||
"fbc_cache",
|
||||
"faster_cache",
|
||||
"single_file",
|
||||
"ip_adapter",
|
||||
"all",
|
||||
],
|
||||
help="Optional testers to include",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
||||
Reference in New Issue
Block a user