mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Fix pipelines user_agent, ignore CI requests (#1058)
* Fix pipelines user_agent, ignore CI requests * fix circular import * N/A versions * N/A versions
This commit is contained in:
1
.github/workflows/pr_tests.yml
vendored
1
.github/workflows/pr_tests.yml
vendored
@@ -10,6 +10,7 @@ concurrency:
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
DIFFUSERS_IS_CI: yes
|
||||
OMP_NUM_THREADS: 8
|
||||
MKL_NUM_THREADS: 8
|
||||
PYTEST_TIMEOUT: 60
|
||||
|
||||
1
.github/workflows/push_tests.yml
vendored
1
.github/workflows/push_tests.yml
vendored
@@ -6,6 +6,7 @@ on:
|
||||
- main
|
||||
|
||||
env:
|
||||
DIFFUSERS_IS_CI: yes
|
||||
HF_HOME: /mnt/cache
|
||||
OMP_NUM_THREADS: 8
|
||||
MKL_NUM_THREADS: 8
|
||||
|
||||
@@ -16,13 +16,25 @@
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Dict, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from huggingface_hub import HfFolder, Repository, whoami
|
||||
|
||||
from .pipeline_utils import DiffusionPipeline
|
||||
from .utils import deprecate, is_modelcards_available, logging
|
||||
from . import __version__
|
||||
from .utils import ENV_VARS_TRUE_VALUES, deprecate, logging
|
||||
from .utils.import_utils import (
|
||||
_flax_version,
|
||||
_jax_version,
|
||||
_onnxruntime_version,
|
||||
_torch_version,
|
||||
is_flax_available,
|
||||
is_modelcards_available,
|
||||
is_onnx_available,
|
||||
is_torch_available,
|
||||
)
|
||||
|
||||
|
||||
if is_modelcards_available():
|
||||
@@ -33,6 +45,32 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "utils" / "model_card_template.md"
|
||||
SESSION_ID = uuid4().hex
|
||||
DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", "").upper() in ENV_VARS_TRUE_VALUES
|
||||
|
||||
|
||||
def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
|
||||
"""
|
||||
Formats a user-agent string with basic info about a request.
|
||||
"""
|
||||
ua = f"diffusers/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}"
|
||||
if DISABLE_TELEMETRY:
|
||||
return ua + "; telemetry/off"
|
||||
if is_torch_available():
|
||||
ua += f"; torch/{_torch_version}"
|
||||
if is_flax_available():
|
||||
ua += f"; jax/{_jax_version}"
|
||||
ua += f"; flax/{_flax_version}"
|
||||
if is_onnx_available():
|
||||
ua += f"; onnxruntime/{_onnxruntime_version}"
|
||||
# CI will set this value to True
|
||||
if os.environ.get("DIFFUSERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES:
|
||||
ua += "; is_ci/true"
|
||||
if isinstance(user_agent, dict):
|
||||
ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items())
|
||||
elif isinstance(user_agent, str):
|
||||
ua += "; " + user_agent
|
||||
return ua
|
||||
|
||||
|
||||
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
|
||||
@@ -101,7 +139,7 @@ def init_git_repo(args, at_init: bool = False):
|
||||
|
||||
def push_to_hub(
|
||||
args,
|
||||
pipeline: DiffusionPipeline,
|
||||
pipeline,
|
||||
repo: Repository,
|
||||
commit_message: Optional[str] = "End of training",
|
||||
blocking: bool = True,
|
||||
|
||||
@@ -29,6 +29,7 @@ from PIL import Image
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from .configuration_utils import ConfigMixin
|
||||
from .hub_utils import http_user_agent
|
||||
from .modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin
|
||||
from .schedulers.scheduling_utils_flax import SCHEDULER_CONFIG_NAME, FlaxSchedulerMixin
|
||||
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, is_transformers_available, logging
|
||||
@@ -301,6 +302,13 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
||||
allow_patterns = [os.path.join(k, "*") for k in folder_names]
|
||||
allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name]
|
||||
|
||||
if cls != FlaxDiffusionPipeline:
|
||||
requested_pipeline_class = cls.__name__
|
||||
else:
|
||||
requested_pipeline_class = config_dict.get("_class_name", cls.__name__)
|
||||
user_agent = {"pipeline_class": requested_pipeline_class}
|
||||
user_agent = http_user_agent(user_agent)
|
||||
|
||||
# download all allow_patterns
|
||||
cached_folder = snapshot_download(
|
||||
pretrained_model_name_or_path,
|
||||
@@ -311,6 +319,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
allow_patterns=allow_patterns,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
else:
|
||||
cached_folder = pretrained_model_name_or_path
|
||||
|
||||
@@ -30,9 +30,9 @@ from packaging import version
|
||||
from PIL import Image
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from . import __version__
|
||||
from .configuration_utils import ConfigMixin
|
||||
from .dynamic_modules_utils import get_class_from_dynamic_module
|
||||
from .hub_utils import http_user_agent
|
||||
from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
|
||||
from .utils import (
|
||||
CONFIG_NAME,
|
||||
@@ -398,10 +398,14 @@ class DiffusionPipeline(ConfigMixin):
|
||||
if custom_pipeline is not None:
|
||||
allow_patterns += [CUSTOM_PIPELINE_FILE_NAME]
|
||||
|
||||
requested_pipeline_class = config_dict.get("_class_name", cls.__name__)
|
||||
user_agent = {"diffusers": __version__, "pipeline_class": requested_pipeline_class}
|
||||
if cls != DiffusionPipeline:
|
||||
requested_pipeline_class = cls.__name__
|
||||
else:
|
||||
requested_pipeline_class = config_dict.get("_class_name", cls.__name__)
|
||||
user_agent = {"pipeline_class": requested_pipeline_class}
|
||||
if custom_pipeline is not None:
|
||||
user_agent["custom_pipeline"] = custom_pipeline
|
||||
user_agent = http_user_agent(user_agent)
|
||||
|
||||
# download all allow_patterns
|
||||
cached_folder = snapshot_download(
|
||||
|
||||
@@ -90,7 +90,8 @@ else:
|
||||
logger.info("Disabling Tensorflow because USE_TORCH is set")
|
||||
_tf_available = False
|
||||
|
||||
|
||||
_jax_version = "N/A"
|
||||
_flax_version = "N/A"
|
||||
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
|
||||
_flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None
|
||||
if _flax_available:
|
||||
@@ -136,6 +137,7 @@ except importlib_metadata.PackageNotFoundError:
|
||||
_modelcards_available = False
|
||||
|
||||
|
||||
_onnxruntime_version = "N/A"
|
||||
_onnx_available = importlib.util.find_spec("onnxruntime") is not None
|
||||
if _onnx_available:
|
||||
candidates = ("onnxruntime", "onnxruntime-gpu", "onnxruntime-directml", "onnxruntime-openvino")
|
||||
|
||||
Reference in New Issue
Block a user