1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[perf] Cache version checks (#12399)

This commit is contained in:
Charles
2025-10-06 17:45:34 +02:00
committed by GitHub
parent 7f3e9b8695
commit cf4b97b233

View File

@@ -21,6 +21,7 @@ import operator as op
import os
import sys
from collections import OrderedDict, defaultdict
from functools import lru_cache as cache
from itertools import chain
from types import ModuleType
from typing import Any, Tuple, Union
@@ -673,6 +674,7 @@ def compare_versions(library_or_version: Union[str, Version], operation: str, re
# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338
@cache
def is_torch_version(operation: str, version: str):
"""
Compares the current PyTorch version to a given reference with an operation.
@@ -686,6 +688,7 @@ def is_torch_version(operation: str, version: str):
return compare_versions(parse(_torch_version), operation, version)
@cache
def is_torch_xla_version(operation: str, version: str):
"""
Compares the current torch_xla version to a given reference with an operation.
@@ -701,6 +704,7 @@ def is_torch_xla_version(operation: str, version: str):
return compare_versions(parse(_torch_xla_version), operation, version)
@cache
def is_transformers_version(operation: str, version: str):
"""
Compares the current Transformers version to a given reference with an operation.
@@ -716,6 +720,7 @@ def is_transformers_version(operation: str, version: str):
return compare_versions(parse(_transformers_version), operation, version)
@cache
def is_hf_hub_version(operation: str, version: str):
"""
Compares the current Hugging Face Hub version to a given reference with an operation.
@@ -731,6 +736,7 @@ def is_hf_hub_version(operation: str, version: str):
return compare_versions(parse(_hf_hub_version), operation, version)
@cache
def is_accelerate_version(operation: str, version: str):
"""
Compares the current Accelerate version to a given reference with an operation.
@@ -746,6 +752,7 @@ def is_accelerate_version(operation: str, version: str):
return compare_versions(parse(_accelerate_version), operation, version)
@cache
def is_peft_version(operation: str, version: str):
"""
Compares the current PEFT version to a given reference with an operation.
@@ -761,6 +768,7 @@ def is_peft_version(operation: str, version: str):
return compare_versions(parse(_peft_version), operation, version)
@cache
def is_bitsandbytes_version(operation: str, version: str):
"""
Args:
@@ -775,6 +783,7 @@ def is_bitsandbytes_version(operation: str, version: str):
return compare_versions(parse(_bitsandbytes_version), operation, version)
@cache
def is_gguf_version(operation: str, version: str):
"""
Compares the current Accelerate version to a given reference with an operation.
@@ -790,6 +799,7 @@ def is_gguf_version(operation: str, version: str):
return compare_versions(parse(_gguf_version), operation, version)
@cache
def is_torchao_version(operation: str, version: str):
"""
Compares the current torchao version to a given reference with an operation.
@@ -805,6 +815,7 @@ def is_torchao_version(operation: str, version: str):
return compare_versions(parse(_torchao_version), operation, version)
@cache
def is_k_diffusion_version(operation: str, version: str):
"""
Compares the current k-diffusion version to a given reference with an operation.
@@ -820,6 +831,7 @@ def is_k_diffusion_version(operation: str, version: str):
return compare_versions(parse(_k_diffusion_version), operation, version)
@cache
def is_optimum_quanto_version(operation: str, version: str):
"""
Compares the current Accelerate version to a given reference with an operation.
@@ -835,6 +847,7 @@ def is_optimum_quanto_version(operation: str, version: str):
return compare_versions(parse(_optimum_quanto_version), operation, version)
@cache
def is_nvidia_modelopt_version(operation: str, version: str):
"""
Compares the current Nvidia ModelOpt version to a given reference with an operation.
@@ -850,6 +863,7 @@ def is_nvidia_modelopt_version(operation: str, version: str):
return compare_versions(parse(_nvidia_modelopt_version), operation, version)
@cache
def is_xformers_version(operation: str, version: str):
"""
Compares the current xformers version to a given reference with an operation.
@@ -865,6 +879,7 @@ def is_xformers_version(operation: str, version: str):
return compare_versions(parse(_xformers_version), operation, version)
@cache
def is_sageattention_version(operation: str, version: str):
"""
Compares the current sageattention version to a given reference with an operation.
@@ -880,6 +895,7 @@ def is_sageattention_version(operation: str, version: str):
return compare_versions(parse(_sageattention_version), operation, version)
@cache
def is_flash_attn_version(operation: str, version: str):
"""
Compares the current flash-attention version to a given reference with an operation.