From 913986afa53ace2b0becc20535ef7c32cb15276a Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 17 Nov 2023 12:05:26 +0100 Subject: [PATCH] Improve setup.py and add dependency check (#5826) * put peft in requirements * correct peft * correct installs * make style * make style --- setup.py | 2 + src/diffusers/dependency_versions_check.py | 15 +-- src/diffusers/dependency_versions_table.py | 2 + src/diffusers/utils/constants.py | 7 +- src/diffusers/utils/versions.py | 117 +++++++++++++++++++++ 5 files changed, 128 insertions(+), 15 deletions(-) create mode 100644 src/diffusers/utils/versions.py diff --git a/setup.py b/setup.py index 9bed326b44..0d048b630f 100644 --- a/setup.py +++ b/setup.py @@ -113,10 +113,12 @@ _deps = [ "numpy", "omegaconf", "parameterized", + "peft<=0.6.2", "protobuf>=3.20.3,<4", "pytest", "pytest-timeout", "pytest-xdist", + "python>=3.8.0", "ruff==0.0.280", "safetensors>=0.3.1", "sentencepiece>=0.1.91,!=0.1.92", diff --git a/src/diffusers/dependency_versions_check.py b/src/diffusers/dependency_versions_check.py index 4f8578c529..0144db201a 100644 --- a/src/diffusers/dependency_versions_check.py +++ b/src/diffusers/dependency_versions_check.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import sys from .dependency_versions_table import deps from .utils.versions import require_version, require_version_core @@ -23,21 +22,9 @@ from .utils.versions import require_version, require_version_core # order specific notes: # - tqdm must be checked before tokenizers -pkgs_to_check_at_runtime = "python tqdm regex requests packaging filelock numpy tokenizers".split() -if sys.version_info < (3, 7): - pkgs_to_check_at_runtime.append("dataclasses") -if sys.version_info < (3, 8): - pkgs_to_check_at_runtime.append("importlib_metadata") - +pkgs_to_check_at_runtime = "python requests filelock numpy".split() for pkg in pkgs_to_check_at_runtime: if pkg in deps: - if pkg == "tokenizers": - # must be loaded here, or else tqdm check may fail - from .utils import is_tokenizers_available - - if not is_tokenizers_available(): - continue # not required, check version only if installed - require_version_core(deps[pkg]) else: raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py") diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index b047064760..143e706ef7 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -25,10 +25,12 @@ deps = { "numpy": "numpy", "omegaconf": "omegaconf", "parameterized": "parameterized", + "peft": "peft<=0.6.2", "protobuf": "protobuf>=3.20.3,<4", "pytest": "pytest", "pytest-timeout": "pytest-timeout", "pytest-xdist": "pytest-xdist", + "python": "python>=3.8.0", "ruff": "ruff==0.0.280", "safetensors": "safetensors>=0.3.1", "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92", diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py index 8ae5b0dec4..608a751fb8 100644 --- a/src/diffusers/utils/constants.py +++ b/src/diffusers/utils/constants.py @@ -17,13 +17,15 @@ import os from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE, hf_cache_home from packaging import version -from .import_utils import is_peft_available, is_transformers_available +from ..dependency_versions_check import dep_version_check +from .import_utils import ENV_VARS_TRUE_VALUES, is_peft_available, is_transformers_available default_cache_path = HUGGINGFACE_HUB_CACHE MIN_PEFT_VERSION = "0.6.0" MIN_TRANSFORMERS_VERSION = "4.34.0" +_CHECK_PEFT = os.environ.get("_CHECK_PEFT", "1") in ENV_VARS_TRUE_VALUES CONFIG_NAME = "config.json" @@ -50,3 +52,6 @@ _required_transformers_version = is_transformers_available() and version.parse( ) >= version.parse(MIN_TRANSFORMERS_VERSION) USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version + +if USE_PEFT_BACKEND and _CHECK_PEFT: + dep_version_check("peft") diff --git a/src/diffusers/utils/versions.py b/src/diffusers/utils/versions.py new file mode 100644 index 0000000000..945a3977ce --- /dev/null +++ b/src/diffusers/utils/versions.py @@ -0,0 +1,117 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Utilities for working with package versions +""" + +import importlib.metadata +import operator +import re +import sys +from typing import Optional + +from packaging import version + + +ops = { + "<": operator.lt, + "<=": operator.le, + "==": operator.eq, + "!=": operator.ne, + ">=": operator.ge, + ">": operator.gt, +} + + +def _compare_versions(op, got_ver, want_ver, requirement, pkg, hint): + if got_ver is None or want_ver is None: + raise ValueError( + f"Unable to compare versions for {requirement}: need={want_ver} found={got_ver}. This is unusual. Consider" + f" reinstalling {pkg}." + ) + if not ops[op](version.parse(got_ver), version.parse(want_ver)): + raise ImportError( + f"{requirement} is required for a normal functioning of this module, but found {pkg}=={got_ver}.{hint}" + ) + + +def require_version(requirement: str, hint: Optional[str] = None) -> None: + """ + Perform a runtime check of the dependency versions, using the exact same syntax used by pip. + + The installed module version comes from the *site-packages* dir via *importlib.metadata*. + + Args: + requirement (`str`): pip style definition, e.g., "tokenizers==0.9.4", "tqdm>=4.27", "numpy" + hint (`str`, *optional*): what suggestion to print in case of requirements not being met + + Example: + + ```python + require_version("pandas>1.1.2") + require_version("numpy>1.18.5", "this is important to have for whatever reason") + ```""" + + hint = f"\n{hint}" if hint is not None else "" + + # non-versioned check + if re.match(r"^[\w_\-\d]+$", requirement): + pkg, op, want_ver = requirement, None, None + else: + match = re.findall(r"^([^!=<>\s]+)([\s!=<>]{1,2}.+)", requirement) + if not match: + raise ValueError( + "requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23, but" + f" got {requirement}" + ) + pkg, want_full = match[0] + want_range = want_full.split(",") # there could be multiple requirements + wanted = {} + for w in want_range: + match = re.findall(r"^([\s!=<>]{1,2})(.+)", w) + if not match: + raise ValueError( + "requirement needs to be in the pip package format, .e.g., package_a==1.23, or package_b>=1.23," + f" but got {requirement}" + ) + op, want_ver = match[0] + wanted[op] = want_ver + if op not in ops: + raise ValueError(f"{requirement}: need one of {list(ops.keys())}, but got {op}") + + # special case + if pkg == "python": + got_ver = ".".join([str(x) for x in sys.version_info[:3]]) + for op, want_ver in wanted.items(): + _compare_versions(op, got_ver, want_ver, requirement, pkg, hint) + return + + # check if any version is installed + try: + got_ver = importlib.metadata.version(pkg) + except importlib.metadata.PackageNotFoundError: + raise importlib.metadata.PackageNotFoundError( + f"The '{requirement}' distribution was not found and is required by this application. {hint}" + ) + + # check that the right version is installed if version number or a range was provided + if want_ver is not None: + for op, want_ver in wanted.items(): + _compare_versions(op, got_ver, want_ver, requirement, pkg, hint) + + +def require_version_core(requirement): + """require_version wrapper which emits a core-specific hint on failure""" + hint = "Try: pip install transformers -U or pip install -e '.[dev]' if you're working with git main" + return require_version(requirement, hint)