1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[perf] Cache version checks

I recently noticed that we are spending a non-negligible amount of time in `version.parse` when running pipelines (approximately 50 ms per step for the QwenImageEdit pipeline on a ZeroGPU Space, for instance, which in this case represents almost 10% of the actual compute). The calls to those version checks originate from:
- 4588bbeb42/src/diffusers/hooks/hooks.py (L277)

Maybe the issue can otherwise be solved at the root (why do we need to unwrap the modules at each call?), or maybe my particular setup triggered this? (I patched the forward method at the block level, but I don't feel like it has an impact on _set_context.)
This commit is contained in:
Charles
2025-09-26 17:28:55 +02:00
committed by GitHub
parent 4588bbeb42
commit 2ca3cadb35

View File

@@ -21,6 +21,7 @@ import operator as op
import os
import sys
from collections import OrderedDict, defaultdict
from functools import cache
from itertools import chain
from types import ModuleType
from typing import Any, Tuple, Union
@@ -673,6 +674,7 @@ def compare_versions(library_or_version: Union[str, Version], operation: str, re
# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338
@cache
def is_torch_version(operation: str, version: str):
"""
Compares the current PyTorch version to a given reference with an operation.
@@ -686,6 +688,7 @@ def is_torch_version(operation: str, version: str):
return compare_versions(parse(_torch_version), operation, version)
@cache
def is_torch_xla_version(operation: str, version: str):
"""
Compares the current torch_xla version to a given reference with an operation.
@@ -701,6 +704,7 @@ def is_torch_xla_version(operation: str, version: str):
return compare_versions(parse(_torch_xla_version), operation, version)
@cache
def is_transformers_version(operation: str, version: str):
"""
Compares the current Transformers version to a given reference with an operation.
@@ -716,6 +720,7 @@ def is_transformers_version(operation: str, version: str):
return compare_versions(parse(_transformers_version), operation, version)
@cache
def is_hf_hub_version(operation: str, version: str):
"""
Compares the current Hugging Face Hub version to a given reference with an operation.
@@ -731,6 +736,7 @@ def is_hf_hub_version(operation: str, version: str):
return compare_versions(parse(_hf_hub_version), operation, version)
@cache
def is_accelerate_version(operation: str, version: str):
"""
Compares the current Accelerate version to a given reference with an operation.
@@ -746,6 +752,7 @@ def is_accelerate_version(operation: str, version: str):
return compare_versions(parse(_accelerate_version), operation, version)
@cache
def is_peft_version(operation: str, version: str):
"""
Compares the current PEFT version to a given reference with an operation.
@@ -761,6 +768,7 @@ def is_peft_version(operation: str, version: str):
return compare_versions(parse(_peft_version), operation, version)
@cache
def is_bitsandbytes_version(operation: str, version: str):
"""
Args:
@@ -775,6 +783,7 @@ def is_bitsandbytes_version(operation: str, version: str):
return compare_versions(parse(_bitsandbytes_version), operation, version)
@cache
def is_gguf_version(operation: str, version: str):
"""
Compares the current Accelerate version to a given reference with an operation.
@@ -790,6 +799,7 @@ def is_gguf_version(operation: str, version: str):
return compare_versions(parse(_gguf_version), operation, version)
@cache
def is_torchao_version(operation: str, version: str):
"""
Compares the current torchao version to a given reference with an operation.
@@ -805,6 +815,7 @@ def is_torchao_version(operation: str, version: str):
return compare_versions(parse(_torchao_version), operation, version)
@cache
def is_k_diffusion_version(operation: str, version: str):
"""
Compares the current k-diffusion version to a given reference with an operation.
@@ -820,6 +831,7 @@ def is_k_diffusion_version(operation: str, version: str):
return compare_versions(parse(_k_diffusion_version), operation, version)
@cache
def is_optimum_quanto_version(operation: str, version: str):
"""
Compares the current Accelerate version to a given reference with an operation.
@@ -835,6 +847,7 @@ def is_optimum_quanto_version(operation: str, version: str):
return compare_versions(parse(_optimum_quanto_version), operation, version)
@cache
def is_nvidia_modelopt_version(operation: str, version: str):
"""
Compares the current Nvidia ModelOpt version to a given reference with an operation.
@@ -850,6 +863,7 @@ def is_nvidia_modelopt_version(operation: str, version: str):
return compare_versions(parse(_nvidia_modelopt_version), operation, version)
@cache
def is_xformers_version(operation: str, version: str):
"""
Compares the current xformers version to a given reference with an operation.
@@ -865,6 +879,7 @@ def is_xformers_version(operation: str, version: str):
return compare_versions(parse(_xformers_version), operation, version)
@cache
def is_sageattention_version(operation: str, version: str):
"""
Compares the current sageattention version to a given reference with an operation.
@@ -880,6 +895,7 @@ def is_sageattention_version(operation: str, version: str):
return compare_versions(parse(_sageattention_version), operation, version)
@cache
def is_flash_attn_version(operation: str, version: str):
"""
Compares the current flash-attention version to a given reference with an operation.