1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Fix pipelines user_agent, ignore CI requests (#1058)

* Fix pipelines user_agent, ignore CI requests

* fix circular import

* N/A versions

* N/A versions
This commit is contained in:
Anton Lozhkov
2022-10-31 13:38:43 +01:00
committed by GitHub
parent 82d56cf192
commit 1606eb994a
6 changed files with 63 additions and 8 deletions

View File

@@ -10,6 +10,7 @@ concurrency:
cancel-in-progress: true
env:
DIFFUSERS_IS_CI: yes
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8
PYTEST_TIMEOUT: 60

View File

@@ -6,6 +6,7 @@ on:
- main
env:
DIFFUSERS_IS_CI: yes
HF_HOME: /mnt/cache
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8

View File

@@ -16,13 +16,25 @@
import os
import shutil
import sys
from pathlib import Path
from typing import Optional
from typing import Dict, Optional, Union
from uuid import uuid4
from huggingface_hub import HfFolder, Repository, whoami
from .pipeline_utils import DiffusionPipeline
from .utils import deprecate, is_modelcards_available, logging
from . import __version__
from .utils import ENV_VARS_TRUE_VALUES, deprecate, logging
from .utils.import_utils import (
_flax_version,
_jax_version,
_onnxruntime_version,
_torch_version,
is_flax_available,
is_modelcards_available,
is_onnx_available,
is_torch_available,
)
if is_modelcards_available():
@@ -33,6 +45,32 @@ logger = logging.get_logger(__name__)
MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "utils" / "model_card_template.md"
SESSION_ID = uuid4().hex
DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", "").upper() in ENV_VARS_TRUE_VALUES
def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
"""
Formats a user-agent string with basic info about a request.
"""
ua = f"diffusers/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}"
if DISABLE_TELEMETRY:
return ua + "; telemetry/off"
if is_torch_available():
ua += f"; torch/{_torch_version}"
if is_flax_available():
ua += f"; jax/{_jax_version}"
ua += f"; flax/{_flax_version}"
if is_onnx_available():
ua += f"; onnxruntime/{_onnxruntime_version}"
# CI will set this value to True
if os.environ.get("DIFFUSERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES:
ua += "; is_ci/true"
if isinstance(user_agent, dict):
ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items())
elif isinstance(user_agent, str):
ua += "; " + user_agent
return ua
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
@@ -101,7 +139,7 @@ def init_git_repo(args, at_init: bool = False):
def push_to_hub(
args,
pipeline: DiffusionPipeline,
pipeline,
repo: Repository,
commit_message: Optional[str] = "End of training",
blocking: bool = True,

View File

@@ -29,6 +29,7 @@ from PIL import Image
from tqdm.auto import tqdm
from .configuration_utils import ConfigMixin
from .hub_utils import http_user_agent
from .modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin
from .schedulers.scheduling_utils_flax import SCHEDULER_CONFIG_NAME, FlaxSchedulerMixin
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, is_transformers_available, logging
@@ -301,6 +302,13 @@ class FlaxDiffusionPipeline(ConfigMixin):
allow_patterns = [os.path.join(k, "*") for k in folder_names]
allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name]
if cls != FlaxDiffusionPipeline:
requested_pipeline_class = cls.__name__
else:
requested_pipeline_class = config_dict.get("_class_name", cls.__name__)
user_agent = {"pipeline_class": requested_pipeline_class}
user_agent = http_user_agent(user_agent)
# download all allow_patterns
cached_folder = snapshot_download(
pretrained_model_name_or_path,
@@ -311,6 +319,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
use_auth_token=use_auth_token,
revision=revision,
allow_patterns=allow_patterns,
user_agent=user_agent,
)
else:
cached_folder = pretrained_model_name_or_path

View File

@@ -30,9 +30,9 @@ from packaging import version
from PIL import Image
from tqdm.auto import tqdm
from . import __version__
from .configuration_utils import ConfigMixin
from .dynamic_modules_utils import get_class_from_dynamic_module
from .hub_utils import http_user_agent
from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from .utils import (
CONFIG_NAME,
@@ -398,10 +398,14 @@ class DiffusionPipeline(ConfigMixin):
if custom_pipeline is not None:
allow_patterns += [CUSTOM_PIPELINE_FILE_NAME]
requested_pipeline_class = config_dict.get("_class_name", cls.__name__)
user_agent = {"diffusers": __version__, "pipeline_class": requested_pipeline_class}
if cls != DiffusionPipeline:
requested_pipeline_class = cls.__name__
else:
requested_pipeline_class = config_dict.get("_class_name", cls.__name__)
user_agent = {"pipeline_class": requested_pipeline_class}
if custom_pipeline is not None:
user_agent["custom_pipeline"] = custom_pipeline
user_agent = http_user_agent(user_agent)
# download all allow_patterns
cached_folder = snapshot_download(

View File

@@ -90,7 +90,8 @@ else:
logger.info("Disabling Tensorflow because USE_TORCH is set")
_tf_available = False
_jax_version = "N/A"
_flax_version = "N/A"
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
_flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None
if _flax_available:
@@ -136,6 +137,7 @@ except importlib_metadata.PackageNotFoundError:
_modelcards_available = False
_onnxruntime_version = "N/A"
_onnx_available = importlib.util.find_spec("onnxruntime") is not None
if _onnx_available:
candidates = ("onnxruntime", "onnxruntime-gpu", "onnxruntime-directml", "onnxruntime-openvino")