Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
Add is_torch_available, is_flax_available (#204)
* Add is_<framework>_available, refactor import utils
* deps
* quality
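The helpers this commit exposes let callers gate framework-specific code on what is actually installed. A minimal usage sketch (assuming a diffusers build that includes this change, which the import hunk below confirms exports these functions):

    from diffusers.utils import is_flax_available, is_torch_available

    # Gate framework-specific work on what is actually importable, so the
    # same script runs on torch-only, flax-only, or CPU-only environments.
    if is_torch_available():
        import torch
        print("PyTorch available:", torch.__version__)
    if is_flax_available():
        import flax
        print("Flax available:", flax.__version__)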
@@ -22,14 +22,13 @@ from collections import OrderedDict
 from difflib import get_close_matches
 from pathlib import Path

-from transformers import is_flax_available, is_tf_available, is_torch_available
-from transformers.models.auto import get_values
-from transformers.utils import ENV_VARS_TRUE_VALUES
+from diffusers.models.auto import get_values
+from diffusers.utils import ENV_VARS_TRUE_VALUES, is_flax_available, is_tf_available, is_torch_available


 # All paths are set with the intent you should run this script from the root of the repo with the command
 # python utils/check_repo.py
-PATH_TO_TRANSFORMERS = "src/transformers"
+PATH_TO_DIFFUSERS = "src/diffusers"
 PATH_TO_TESTS = "tests"
 PATH_TO_DOC = "docs/source/en"

@@ -200,17 +199,17 @@ MODEL_TYPE_TO_DOC_MAPPING = OrderedDict(

 # This is to make sure the transformers module imported is the one in the repo.
 spec = importlib.util.spec_from_file_location(
-    "transformers",
-    os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"),
-    submodule_search_locations=[PATH_TO_TRANSFORMERS],
+    "diffusers",
+    os.path.join(PATH_TO_DIFFUSERS, "__init__.py"),
+    submodule_search_locations=[PATH_TO_DIFFUSERS],
 )
-transformers = spec.loader.load_module()
+diffusers = spec.loader.load_module()


 def check_model_list():
     """Check the model list inside the transformers library."""
-    # Get the models from the directory structure of `src/transformers/models/`
-    models_dir = os.path.join(PATH_TO_TRANSFORMERS, "models")
+    # Get the models from the directory structure of `src/diffusers/models/`
+    models_dir = os.path.join(PATH_TO_DIFFUSERS, "models")
     _models = []
     for model in os.listdir(models_dir):
         model_dir = os.path.join(models_dir, model)
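The spec_from_file_location dance above forces the checks to run against the repo checkout rather than any pip-installed copy of the package. Note that `Loader.load_module()` is deprecated in modern Python; a standalone sketch of the supported equivalent (the names and paths mirror the script but this is illustrative, not part of the commit):

    import importlib.util
    import sys

    spec = importlib.util.spec_from_file_location(
        "diffusers",
        "src/diffusers/__init__.py",
        submodule_search_locations=["src/diffusers"],
    )
    diffusers = importlib.util.module_from_spec(spec)
    # Register the module before executing it so the package's own imports
    # resolve to this in-repo copy instead of an installed one.
    sys.modules[spec.name] = diffusers
    spec.loader.exec_module(diffusers)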
@@ -218,7 +217,7 @@ def check_model_list():
             _models.append(model)

     # Get the models from the directory structure of `src/transformers/models/`
-    models = [model for model in dir(transformers.models) if not model.startswith("__")]
+    models = [model for model in dir(diffusers.models) if not model.startswith("__")]

     missing_models = sorted(list(set(_models).difference(models)))
     if missing_models:
@@ -256,10 +255,10 @@ def get_model_modules():
         "modeling_vision_encoder_decoder",
     ]
     modules = []
-    for model in dir(transformers.models):
+    for model in dir(diffusers.models):
         # There are some magic dunder attributes in the dir, we ignore them
         if not model.startswith("__"):
-            model_module = getattr(transformers.models, model)
+            model_module = getattr(diffusers.models, model)
             for submodule in dir(model_module):
                 if submodule.startswith("modeling") and submodule not in _ignore_modules:
                     modeling_module = getattr(model_module, submodule)
@@ -271,7 +270,7 @@ def get_model_modules():
 def get_models(module, include_pretrained=False):
     """Get the objects in module that are models."""
     models = []
-    model_classes = (transformers.ModelMixin, transformers.TFModelMixin, transformers.FlaxModelMixin)
+    model_classes = (diffusers.ModelMixin, diffusers.TFModelMixin, diffusers.FlaxModelMixin)
     for attr_name in dir(module):
         if not include_pretrained and ("Pretrained" in attr_name or "PreTrained" in attr_name):
             continue
@@ -299,7 +298,7 @@ def is_a_private_model(model):
 def check_models_are_in_init():
     """Checks all models defined in the library are in the main init."""
     models_not_in_init = []
-    dir_transformers = dir(transformers)
+    dir_transformers = dir(diffusers)
     for module in get_model_modules():
         models_not_in_init += [
             model[0] for model in get_models(module, include_pretrained=True) if model[0] not in dir_transformers
@@ -419,17 +418,17 @@ def get_all_auto_configured_models():
     """Return the list of all models in at least one auto class."""
     result = set()  # To avoid duplicates we concatenate all model classes in a set.
     if is_torch_available():
-        for attr_name in dir(transformers.models.auto.modeling_auto):
+        for attr_name in dir(diffusers.models.auto.modeling_auto):
             if attr_name.startswith("MODEL_") and attr_name.endswith("MAPPING_NAMES"):
-                result = result | set(get_values(getattr(transformers.models.auto.modeling_auto, attr_name)))
+                result = result | set(get_values(getattr(diffusers.models.auto.modeling_auto, attr_name)))
     if is_tf_available():
-        for attr_name in dir(transformers.models.auto.modeling_tf_auto):
+        for attr_name in dir(diffusers.models.auto.modeling_tf_auto):
             if attr_name.startswith("TF_MODEL_") and attr_name.endswith("MAPPING_NAMES"):
-                result = result | set(get_values(getattr(transformers.models.auto.modeling_tf_auto, attr_name)))
+                result = result | set(get_values(getattr(diffusers.models.auto.modeling_tf_auto, attr_name)))
     if is_flax_available():
-        for attr_name in dir(transformers.models.auto.modeling_flax_auto):
+        for attr_name in dir(diffusers.models.auto.modeling_flax_auto):
             if attr_name.startswith("FLAX_MODEL_") and attr_name.endswith("MAPPING_NAMES"):
-                result = result | set(get_values(getattr(transformers.models.auto.modeling_flax_auto, attr_name)))
+                result = result | set(get_values(getattr(diffusers.models.auto.modeling_flax_auto, attr_name)))
     return [cls for cls in result]

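Each framework block above follows the same shape: enumerate a module's attributes, keep the mappings whose names match a prefix/suffix convention, and union their values. A generic sketch of that pattern (the flattening loop approximates what get_values() from the auto module does; treat it as an assumption, not the library's implementation):

    def collect_mapping_values(module, prefix="MODEL_", suffix="MAPPING_NAMES"):
        result = set()
        for attr_name in dir(module):
            if attr_name.startswith(prefix) and attr_name.endswith(suffix):
                mapping = getattr(module, attr_name)
                for value in mapping.values():
                    # Some mappings hold lists/tuples of classes; flatten them.
                    if isinstance(value, (list, tuple)):
                        result.update(value)
                    else:
                        result.add(value)
        return result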
@@ -636,8 +635,8 @@ def ignore_undocumented(name):
     ):
         return True
     # Submodules are not documented.
-    if os.path.isdir(os.path.join(PATH_TO_TRANSFORMERS, name)) or os.path.isfile(
-        os.path.join(PATH_TO_TRANSFORMERS, f"{name}.py")
+    if os.path.isdir(os.path.join(PATH_TO_DIFFUSERS, name)) or os.path.isfile(
+        os.path.join(PATH_TO_DIFFUSERS, f"{name}.py")
     ):
         return True
     # All load functions are not documented.
@@ -660,8 +659,8 @@ def ignore_undocumented(name):
 def check_all_objects_are_documented():
     """Check all models are properly documented."""
     documented_objs = find_all_documented_objects()
-    modules = transformers._modules
-    objects = [c for c in dir(transformers) if c not in modules and not c.startswith("_")]
+    modules = diffusers._modules
+    objects = [c for c in dir(diffusers) if c not in modules and not c.startswith("_")]
     undocumented_objs = [c for c in objects if c not in documented_objs and not ignore_undocumented(c)]
     if len(undocumented_objs) > 0:
         raise Exception(
@@ -677,7 +676,7 @@ def check_model_type_doc_match():
     model_doc_folder = Path(PATH_TO_DOC) / "model_doc"
     model_docs = [m.stem for m in model_doc_folder.glob("*.mdx")]

-    model_types = list(transformers.models.auto.configuration_auto.MODEL_NAMES_MAPPING.keys())
+    model_types = list(diffusers.models.auto.configuration_auto.MODEL_NAMES_MAPPING.keys())
     model_types = [MODEL_TYPE_TO_DOC_MAPPING[m] if m in MODEL_TYPE_TO_DOC_MAPPING else m for m in model_types]

     errors = []
@@ -723,7 +722,7 @@ def is_rst_docstring(docstring):
 def check_docstrings_are_in_md():
     """Check all docstrings are in md"""
     files_with_rst = []
-    for file in Path(PATH_TO_TRANSFORMERS).glob("**/*.py"):
+    for file in Path(PATH_TO_DIFFUSERS).glob("**/*.py"):
         with open(file, "r") as f:
             code = f.read()
             docstrings = code.split('"""')
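Splitting a source file on triple quotes is a cheap way to pull out docstrings: the segments at odd indices sit between an opening and a closing `"""`. A sketch of that scan (assumes double-quoted triple quotes, as the split above does; the root path is a placeholder):

    from pathlib import Path

    def iter_docstrings(root="src/diffusers"):
        for file in Path(root).glob("**/*.py"):
            code = file.read_text()
            docstrings = code.split('"""')
            # Odd indices fall inside a pair of triple quotes.
            yield from docstrings[1::2]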