From d0c02398b986f2876b2b79f3a137ed00a7edde35 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Thu, 1 May 2025 12:17:52 -0400 Subject: [PATCH] cache packages_distributions (#11453) * cache packages_distributions * remove unused exception reference * make style Signed-off-by: Vladimir Mandic * change name to _package_map --------- Signed-off-by: Vladimir Mandic Co-authored-by: DN6 --- src/diffusers/utils/import_utils.py | 46 ++++++++++++++--------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 2e055d85fd..406f1d999d 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -36,7 +36,10 @@ if sys.version_info < (3, 8): import importlib_metadata else: import importlib.metadata as importlib_metadata - +try: + _package_map = importlib_metadata.packages_distributions() # load-once to avoid expensive calls +except Exception: + _package_map = None logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -56,35 +59,32 @@ _is_google_colab = "google.colab" in sys.modules or any(k.startswith("COLAB_") f def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[bool, str]: + global _package_map pkg_exists = importlib.util.find_spec(pkg_name) is not None pkg_version = "N/A" if pkg_exists: + if _package_map is None: + _package_map = defaultdict(list) + try: + # Fallback for Python < 3.10 + for dist in importlib_metadata.distributions(): + _top_level_declared = (dist.read_text("top_level.txt") or "").split() + _infered_opt_names = { + f.parts[0] if len(f.parts) > 1 else inspect.getmodulename(f) for f in (dist.files or []) + } - {None} + _top_level_inferred = filter(lambda name: "." not in name, _infered_opt_names) + for pkg in _top_level_declared or _top_level_inferred: + _package_map[pkg].append(dist.metadata["Name"]) + except Exception as _: + pass try: - package_map = importlib_metadata.packages_distributions() - except Exception as e: - package_map = defaultdict(list) - if isinstance(e, AttributeError): - try: - # Fallback for Python < 3.10 - for dist in importlib_metadata.distributions(): - _top_level_declared = (dist.read_text("top_level.txt") or "").split() - _infered_opt_names = { - f.parts[0] if len(f.parts) > 1 else inspect.getmodulename(f) for f in (dist.files or []) - } - {None} - _top_level_inferred = filter(lambda name: "." not in name, _infered_opt_names) - for pkg in _top_level_declared or _top_level_inferred: - package_map[pkg].append(dist.metadata["Name"]) - except Exception as _: - pass - - try: - if get_dist_name and pkg_name in package_map and package_map[pkg_name]: - if len(package_map[pkg_name]) > 1: + if get_dist_name and pkg_name in _package_map and _package_map[pkg_name]: + if len(_package_map[pkg_name]) > 1: logger.warning( - f"Multiple distributions found for package {pkg_name}. Picked distribution: {package_map[pkg_name][0]}" + f"Multiple distributions found for package {pkg_name}. Picked distribution: {_package_map[pkg_name][0]}" ) - pkg_name = package_map[pkg_name][0] + pkg_name = _package_map[pkg_name][0] pkg_version = importlib_metadata.version(pkg_name) logger.debug(f"Successfully imported {pkg_name} version {pkg_version}") except (ImportError, importlib_metadata.PackageNotFoundError):