From e660a05fed40229a747ce63104c867ab143a0db4 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 17 Jun 2022 11:00:01 +0200 Subject: [PATCH] remave onnx --- utils/check_table.py | 47 -------------------------------------------- 1 file changed, 47 deletions(-) diff --git a/utils/check_table.py b/utils/check_table.py index 08b5a23ca9..3a900551a4 100644 --- a/utils/check_table.py +++ b/utils/check_table.py @@ -177,56 +177,9 @@ def check_model_table(overwrite=False): ) -def has_onnx(model_type): - """ - Returns whether `model_type` is supported by ONNX (by checking if there is an ONNX config) or not. - """ - config_mapping = transformers_module.models.auto.configuration_auto.CONFIG_MAPPING - if model_type not in config_mapping: - return False - config = config_mapping[model_type] - config_module = config.__module__ - module = transformers_module - for part in config_module.split(".")[1:]: - module = getattr(module, part) - config_name = config.__name__ - onnx_config_name = config_name.replace("ConfigMixin", "OnnxConfigMixin") - return hasattr(module, onnx_config_name) - - -def get_onnx_model_list(): - """ - Return the list of models supporting ONNX. - """ - config_mapping = transformers_module.models.auto.configuration_auto.CONFIG_MAPPING - model_names = config_mapping = transformers_module.models.auto.configuration_auto.MODEL_NAMES_MAPPING - onnx_model_types = [model_type for model_type in config_mapping.keys() if has_onnx(model_type)] - onnx_model_names = [model_names[model_type] for model_type in onnx_model_types] - onnx_model_names.sort(key=lambda x: x.upper()) - return "\n".join([f"- {name}" for name in onnx_model_names]) + "\n" - - -def check_onnx_model_list(overwrite=False): - """Check the model list in the serialization.mdx is consistent with the state of the lib and maybe `overwrite`.""" - current_list, start_index, end_index, lines = _find_text_in_file( - filename=os.path.join(PATH_TO_DOCS, "serialization.mdx"), - start_prompt="", - end_prompt="In the next two sections, we'll show you how to:", - ) - new_list = get_onnx_model_list() - - if current_list != new_list: - if overwrite: - with open(os.path.join(PATH_TO_DOCS, "serialization.mdx"), "w", encoding="utf-8", newline="\n") as f: - f.writelines(lines[:start_index] + [new_list] + lines[end_index:]) - else: - raise ValueError("The list of ONNX-supported models needs an update. Run `make fix-copies` to fix this.") - - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.") args = parser.parse_args() check_model_table(args.fix_and_overwrite) - check_onnx_model_list(args.fix_and_overwrite)