1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

remave onnx

This commit is contained in:
Patrick von Platen
2022-06-17 11:00:01 +02:00
parent 5e6f500038
commit e660a05fed

View File

@@ -177,56 +177,9 @@ def check_model_table(overwrite=False):
)
def has_onnx(model_type):
"""
Returns whether `model_type` is supported by ONNX (by checking if there is an ONNX config) or not.
"""
config_mapping = transformers_module.models.auto.configuration_auto.CONFIG_MAPPING
if model_type not in config_mapping:
return False
config = config_mapping[model_type]
config_module = config.__module__
module = transformers_module
for part in config_module.split(".")[1:]:
module = getattr(module, part)
config_name = config.__name__
onnx_config_name = config_name.replace("ConfigMixin", "OnnxConfigMixin")
return hasattr(module, onnx_config_name)
def get_onnx_model_list():
"""
Return the list of models supporting ONNX.
"""
config_mapping = transformers_module.models.auto.configuration_auto.CONFIG_MAPPING
model_names = config_mapping = transformers_module.models.auto.configuration_auto.MODEL_NAMES_MAPPING
onnx_model_types = [model_type for model_type in config_mapping.keys() if has_onnx(model_type)]
onnx_model_names = [model_names[model_type] for model_type in onnx_model_types]
onnx_model_names.sort(key=lambda x: x.upper())
return "\n".join([f"- {name}" for name in onnx_model_names]) + "\n"
def check_onnx_model_list(overwrite=False):
"""Check the model list in the serialization.mdx is consistent with the state of the lib and maybe `overwrite`."""
current_list, start_index, end_index, lines = _find_text_in_file(
filename=os.path.join(PATH_TO_DOCS, "serialization.mdx"),
start_prompt="<!--This table is automatically generated by `make fix-copies`, do not fill manually!-->",
end_prompt="In the next two sections, we'll show you how to:",
)
new_list = get_onnx_model_list()
if current_list != new_list:
if overwrite:
with open(os.path.join(PATH_TO_DOCS, "serialization.mdx"), "w", encoding="utf-8", newline="\n") as f:
f.writelines(lines[:start_index] + [new_list] + lines[end_index:])
else:
raise ValueError("The list of ONNX-supported models needs an update. Run `make fix-copies` to fix this.")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
args = parser.parse_args()
check_model_table(args.fix_and_overwrite)
check_onnx_model_list(args.fix_and_overwrite)