diff --git a/setup.py b/setup.py index 7b4d25fd6c..16495da76f 100644 --- a/setup.py +++ b/setup.py @@ -161,10 +161,11 @@ extras = {} extras = {} extras["quality"] = ["black ~= 22.0", "isort >= 5.5.4", "flake8 >= 3.8.3"] extras["docs"] = [] +extras["training"] = ["tensorboard", "modelcards"] extras["test"] = [ "pytest", ] -extras["dev"] = extras["quality"] + extras["test"] +extras["dev"] = extras["quality"] + extras["test"] + extras["training"] install_requires = [ deps["filelock"], @@ -174,8 +175,6 @@ install_requires = [ deps["requests"], deps["torch"], deps["Pillow"], - deps["tensorboard"], - deps["modelcards"], ] setup( diff --git a/src/diffusers/hub_utils.py b/src/diffusers/hub_utils.py index 4aa6180ca8..ae14ba354e 100644 --- a/src/diffusers/hub_utils.py +++ b/src/diffusers/hub_utils.py @@ -21,7 +21,11 @@ from typing import Optional from diffusers import DiffusionPipeline from huggingface_hub import HfFolder, Repository, whoami -from modelcards import CardData, ModelCard +from utils import is_modelcards_available + + +if is_modelcards_available(): + from modelcards import CardData, ModelCard from .utils import logging @@ -147,6 +151,12 @@ def push_to_hub( def create_model_card(args, model_name): + if not is_modelcards_available: + raise ValueError( + "Please make sure to have `modelcards` installed when using the `create_model_card` function. You can" + " install the package with `pip install modelcards`." + ) + if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]: return diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 2c56ba4a8a..c063cfaeb1 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -61,6 +61,14 @@ except importlib_metadata.PackageNotFoundError: _unidecode_available = False +_modelcards_available = importlib.util.find_spec("modelcards") is not None +try: + _modelcards_version = importlib_metadata.version("modelcards") + logger.debug(f"Successfully imported modelcards version {_modelcards_version}") +except importlib_metadata.PackageNotFoundError: + _modelcards_available = False + + def is_transformers_available(): return _transformers_available @@ -73,6 +81,10 @@ def is_unidecode_available(): return _unidecode_available +def is_modelcards_available(): + return _modelcards_available + + class RepositoryNotFoundError(HTTPError): """ Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does