
Add is_torch_available, is_flax_available (#204)

* Add is_<framework>_available, refactor import utils

* deps

* quality
Author: Anton Lozhkov
Date: 2022-08-17 16:47:20 +02:00
Committed by: GitHub
Parent: ed22b4fd07
Commit: df90f0ce98
7 changed files with 312 additions and 185 deletions
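
For readers unfamiliar with these helpers: availability checks like is_torch_available and is_flax_available are normally implemented by probing Python's import machinery rather than importing the heavy framework up front. The sketch below illustrates that common pattern; it is an assumption about the general shape, not the literal code this commit adds under src/diffusers/utils.

# Minimal sketch of framework-availability helpers, assuming the usual
# importlib.util.find_spec pattern; not the exact implementation from this PR.
import importlib.util

_torch_available = importlib.util.find_spec("torch") is not None
_flax_available = (
    importlib.util.find_spec("jax") is not None
    and importlib.util.find_spec("flax") is not None
)

def is_torch_available():
    # True when PyTorch can be imported in the current environment.
    return _torch_available

def is_flax_available():
    # True when both JAX and Flax can be imported.
    return _flax_available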

utils/check_repo.py

@@ -22,14 +22,13 @@ from collections import OrderedDict
 from difflib import get_close_matches
 from pathlib import Path
-from transformers import is_flax_available, is_tf_available, is_torch_available
-from transformers.models.auto import get_values
-from transformers.utils import ENV_VARS_TRUE_VALUES
+from diffusers.models.auto import get_values
+from diffusers.utils import ENV_VARS_TRUE_VALUES, is_flax_available, is_tf_available, is_torch_available
 # All paths are set with the intent you should run this script from the root of the repo with the command
 # python utils/check_repo.py
-PATH_TO_TRANSFORMERS = "src/transformers"
+PATH_TO_DIFFUSERS = "src/diffusers"
 PATH_TO_TESTS = "tests"
 PATH_TO_DOC = "docs/source/en"
@@ -200,17 +199,17 @@ MODEL_TYPE_TO_DOC_MAPPING = OrderedDict(
 # This is to make sure the transformers module imported is the one in the repo.
 spec = importlib.util.spec_from_file_location(
-    "transformers",
-    os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"),
-    submodule_search_locations=[PATH_TO_TRANSFORMERS],
+    "diffusers",
+    os.path.join(PATH_TO_DIFFUSERS, "__init__.py"),
+    submodule_search_locations=[PATH_TO_DIFFUSERS],
 )
-transformers = spec.loader.load_module()
+diffusers = spec.loader.load_module()
 def check_model_list():
     """Check the model list inside the transformers library."""
-    # Get the models from the directory structure of `src/transformers/models/`
-    models_dir = os.path.join(PATH_TO_TRANSFORMERS, "models")
+    # Get the models from the directory structure of `src/diffusers/models/`
+    models_dir = os.path.join(PATH_TO_DIFFUSERS, "models")
     _models = []
     for model in os.listdir(models_dir):
         model_dir = os.path.join(models_dir, model)
@@ -218,7 +217,7 @@ def check_model_list():
             _models.append(model)
     # Get the models from the directory structure of `src/transformers/models/`
-    models = [model for model in dir(transformers.models) if not model.startswith("__")]
+    models = [model for model in dir(diffusers.models) if not model.startswith("__")]
     missing_models = sorted(list(set(_models).difference(models)))
     if missing_models:
@@ -256,10 +255,10 @@ def get_model_modules():
         "modeling_vision_encoder_decoder",
     ]
     modules = []
-    for model in dir(transformers.models):
+    for model in dir(diffusers.models):
         # There are some magic dunder attributes in the dir, we ignore them
         if not model.startswith("__"):
-            model_module = getattr(transformers.models, model)
+            model_module = getattr(diffusers.models, model)
             for submodule in dir(model_module):
                 if submodule.startswith("modeling") and submodule not in _ignore_modules:
                     modeling_module = getattr(model_module, submodule)
@@ -271,7 +270,7 @@ def get_model_modules():
 def get_models(module, include_pretrained=False):
     """Get the objects in module that are models."""
     models = []
-    model_classes = (transformers.ModelMixin, transformers.TFModelMixin, transformers.FlaxModelMixin)
+    model_classes = (diffusers.ModelMixin, diffusers.TFModelMixin, diffusers.FlaxModelMixin)
     for attr_name in dir(module):
         if not include_pretrained and ("Pretrained" in attr_name or "PreTrained" in attr_name):
             continue
@@ -299,7 +298,7 @@ def is_a_private_model(model):
 def check_models_are_in_init():
     """Checks all models defined in the library are in the main init."""
     models_not_in_init = []
-    dir_transformers = dir(transformers)
+    dir_transformers = dir(diffusers)
     for module in get_model_modules():
         models_not_in_init += [
             model[0] for model in get_models(module, include_pretrained=True) if model[0] not in dir_transformers
@@ -419,17 +418,17 @@ def get_all_auto_configured_models():
     """Return the list of all models in at least one auto class."""
     result = set()  # To avoid duplicates we concatenate all model classes in a set.
     if is_torch_available():
-        for attr_name in dir(transformers.models.auto.modeling_auto):
+        for attr_name in dir(diffusers.models.auto.modeling_auto):
             if attr_name.startswith("MODEL_") and attr_name.endswith("MAPPING_NAMES"):
-                result = result | set(get_values(getattr(transformers.models.auto.modeling_auto, attr_name)))
+                result = result | set(get_values(getattr(diffusers.models.auto.modeling_auto, attr_name)))
     if is_tf_available():
-        for attr_name in dir(transformers.models.auto.modeling_tf_auto):
+        for attr_name in dir(diffusers.models.auto.modeling_tf_auto):
             if attr_name.startswith("TF_MODEL_") and attr_name.endswith("MAPPING_NAMES"):
-                result = result | set(get_values(getattr(transformers.models.auto.modeling_tf_auto, attr_name)))
+                result = result | set(get_values(getattr(diffusers.models.auto.modeling_tf_auto, attr_name)))
     if is_flax_available():
-        for attr_name in dir(transformers.models.auto.modeling_flax_auto):
+        for attr_name in dir(diffusers.models.auto.modeling_flax_auto):
             if attr_name.startswith("FLAX_MODEL_") and attr_name.endswith("MAPPING_NAMES"):
-                result = result | set(get_values(getattr(transformers.models.auto.modeling_flax_auto, attr_name)))
+                result = result | set(get_values(getattr(diffusers.models.auto.modeling_flax_auto, attr_name)))
     return [cls for cls in result]
@@ -636,8 +635,8 @@ def ignore_undocumented(name):
     ):
         return True
     # Submodules are not documented.
-    if os.path.isdir(os.path.join(PATH_TO_TRANSFORMERS, name)) or os.path.isfile(
-        os.path.join(PATH_TO_TRANSFORMERS, f"{name}.py")
+    if os.path.isdir(os.path.join(PATH_TO_DIFFUSERS, name)) or os.path.isfile(
+        os.path.join(PATH_TO_DIFFUSERS, f"{name}.py")
     ):
         return True
     # All load functions are not documented.
@@ -660,8 +659,8 @@ def ignore_undocumented(name):
 def check_all_objects_are_documented():
     """Check all models are properly documented."""
     documented_objs = find_all_documented_objects()
-    modules = transformers._modules
-    objects = [c for c in dir(transformers) if c not in modules and not c.startswith("_")]
+    modules = diffusers._modules
+    objects = [c for c in dir(diffusers) if c not in modules and not c.startswith("_")]
     undocumented_objs = [c for c in objects if c not in documented_objs and not ignore_undocumented(c)]
     if len(undocumented_objs) > 0:
         raise Exception(
@@ -677,7 +676,7 @@ def check_model_type_doc_match():
     model_doc_folder = Path(PATH_TO_DOC) / "model_doc"
     model_docs = [m.stem for m in model_doc_folder.glob("*.mdx")]
-    model_types = list(transformers.models.auto.configuration_auto.MODEL_NAMES_MAPPING.keys())
+    model_types = list(diffusers.models.auto.configuration_auto.MODEL_NAMES_MAPPING.keys())
     model_types = [MODEL_TYPE_TO_DOC_MAPPING[m] if m in MODEL_TYPE_TO_DOC_MAPPING else m for m in model_types]
     errors = []
@@ -723,7 +722,7 @@ def is_rst_docstring(docstring):
 def check_docstrings_are_in_md():
     """Check all docstrings are in md"""
     files_with_rst = []
-    for file in Path(PATH_TO_TRANSFORMERS).glob("**/*.py"):
+    for file in Path(PATH_TO_DIFFUSERS).glob("**/*.py"):
         with open(file, "r") as f:
             code = f.read()
         docstrings = code.split('"""')
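
The practical payoff of exporting these checkers from diffusers.utils is that repository tooling such as check_repo.py, and library code in general, can guard optional-framework branches without importing PyTorch or Flax unconditionally, as get_all_auto_configured_models does above. A short usage sketch follows; the collect_backends helper is hypothetical and only illustrates the guard pattern.

# Hypothetical usage of the helpers exposed by this commit; the function itself
# is illustrative and not part of the diff above.
from diffusers.utils import is_flax_available, is_torch_available

def collect_backends():
    backends = []
    if is_torch_available():
        import torch  # only reached when PyTorch is actually installed
        backends.append(("torch", torch.__version__))
    if is_flax_available():
        import flax  # only reached when Flax (and JAX) is actually installed
        backends.append(("flax", flax.__version__))
    return backends

print(collect_backends())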