mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
[Pipeline utils] feat: implement push_to_hub for standalone models, schedulers as well as pipelines (#4128)
* feat: implement push_to_hub for standalone models.
* address PR feedback.
* Apply suggestions from code review
  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* remove max_shard_size.
* add: support for scheduler push_to_hub
* enable push_to_hub support for flax schedulers.
* enable push_to_hub for pipelines.
* Apply suggestions from code review
  Co-authored-by: Lucain <lucainp@gmail.com>
* reflect pr feedback.
* address another round of feedback.
* better handling of kwargs.
* add: tests
* Apply suggestions from code review
  Co-authored-by: Lucain <lucainp@gmail.com>
* setting hub staging to False for now.
* incorporate staging test as a separate job.
  Co-authored-by: ydshieh <2521628+ydshieh@users.noreply.github.com>
* fix: tokenizer loading.
* fix: json dumping.
* move is_staging_test to a better location.
* better treatment of tokens.
* define repo_id to better handle concurrency
* style
* explicitly set token
* Empty-Commit
* move USER, TOKEN to test
* collate org_repo_id
* delete repo

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Lucain <lucainp@gmail.com>
Co-authored-by: ydshieh <2521628+ydshieh@users.noreply.github.com>
.github/workflows/pr_tests.yml (vendored): 57 lines changed
@@ -113,3 +113,60 @@ jobs:
         with:
           name: pr_${{ matrix.config.report }}_test_reports
           path: reports
+
+  run_staging_tests:
+    strategy:
+      fail-fast: false
+      matrix:
+        config:
+          - name: Hub tests for models, schedulers, and pipelines
+            framework: hub_tests_pytorch
+            runner: docker-cpu
+            image: diffusers/diffusers-pytorch-cpu
+            report: torch_hub
+
+    name: ${{ matrix.config.name }}
+
+    runs-on: ${{ matrix.config.runner }}
+
+    container:
+      image: ${{ matrix.config.image }}
+      options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
+
+    defaults:
+      run:
+        shell: bash
+
+    steps:
+      - name: Checkout diffusers
+        uses: actions/checkout@v3
+        with:
+          fetch-depth: 2
+
+      - name: Install dependencies
+        run: |
+          apt-get update && apt-get install libsndfile1-dev libgl1 -y
+          python -m pip install -e .[quality,test]
+
+      - name: Environment
+        run: |
+          python utils/print_env.py
+
+      - name: Run Hub tests for models, schedulers, and pipelines on a staging env
+        if: ${{ matrix.config.framework == 'hub_tests_pytorch' }}
+        run: |
+          HUGGINGFACE_CO_STAGING=true python -m pytest \
+            -m "is_staging_test" \
+            --make-reports=tests_${{ matrix.config.report }} \
+            tests
+
+      - name: Failure short reports
+        if: ${{ failure() }}
+        run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
+
+      - name: Test suite reports artifacts
+        if: ${{ always() }}
+        uses: actions/upload-artifact@v2
+        with:
+          name: pr_${{ matrix.config.report }}_test_reports
+          path: reports
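For local debugging, the CI selection above can be reproduced without GitHub Actions. A minimal sketch (assuming the `is_staging_test` marker defined later in this diff); the environment variable must be set before the test utilities are imported, because the flag is read at import time:

```python
# Run only the Hub staging tests, mirroring the CI step above.
import os

os.environ["HUGGINGFACE_CO_STAGING"] = "true"  # must precede the import of tests/others/test_utils.py

import pytest

raise SystemExit(pytest.main(["-m", "is_staging_test", "tests"]))
```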
@@ -9,4 +9,8 @@ All models are built from the base [`ModelMixin`] class which is a [`torch.nn.Module`]
 
 ## FlaxModelMixin
 
-[[autodoc]] FlaxModelMixin
+[[autodoc]] FlaxModelMixin
+
+## Pushing to the Hub
+
+[[autodoc]] utils.PushToHubMixin
@@ -26,7 +26,7 @@ from pathlib import PosixPath
 from typing import Any, Dict, Tuple, Union
 
 import numpy as np
-from huggingface_hub import hf_hub_download
+from huggingface_hub import create_repo, hf_hub_download
 from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
 from requests import HTTPError
 
@@ -144,6 +144,12 @@ class ConfigMixin:
         Args:
             save_directory (`str` or `os.PathLike`):
                 Directory where the configuration JSON file is saved (will be created if it does not exist).
+            push_to_hub (`bool`, *optional*, defaults to `False`):
+                Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
+                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+                namespace).
+            kwargs (`Dict[str, Any]`, *optional*):
+                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
         """
         if os.path.isfile(save_directory):
             raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
@@ -156,6 +162,22 @@ class ConfigMixin:
         self.to_json_file(output_config_file)
         logger.info(f"Configuration saved in {output_config_file}")
 
+        if push_to_hub:
+            commit_message = kwargs.pop("commit_message", None)
+            private = kwargs.pop("private", False)
+            create_pr = kwargs.pop("create_pr", False)
+            token = kwargs.pop("token", None)
+            repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+            repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
+
+            self._upload_folder(
+                save_directory,
+                repo_id,
+                token=token,
+                commit_message=commit_message,
+                create_pr=create_pr,
+            )
+
     @classmethod
     def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
         r"""
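With the docstring and body above in place, `save_config` can both write locally and publish in one call. A hedged usage sketch (the directory and repo names are illustrative, and `private` is one of the kwargs popped by the block above):

```python
# Save a scheduler config locally and push it to the Hub in one call.
from diffusers import DDIMScheduler

scheduler = DDIMScheduler()
scheduler.save_config(
    "ddim-config",              # local directory; its basename is the default repo name
    push_to_hub=True,
    repo_id="my-ddim-config",   # optional override, popped from **kwargs
    private=True,               # forwarded to create_repo
)
```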
@@ -23,7 +23,7 @@ import msgpack.exceptions
 from flax.core.frozen_dict import FrozenDict, unfreeze
 from flax.serialization import from_bytes, to_bytes
 from flax.traverse_util import flatten_dict, unflatten_dict
-from huggingface_hub import hf_hub_download
+from huggingface_hub import create_repo, hf_hub_download
 from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
 from requests import HTTPError
 
@@ -34,6 +34,7 @@ from ..utils import (
     FLAX_WEIGHTS_NAME,
     HUGGINGFACE_CO_RESOLVE_ENDPOINT,
     WEIGHTS_NAME,
+    PushToHubMixin,
     logging,
 )
 from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
@@ -42,7 +43,7 @@ from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
 logger = logging.get_logger(__name__)
 
 
-class FlaxModelMixin:
+class FlaxModelMixin(PushToHubMixin):
     r"""
     Base class for all Flax models.
 
@@ -497,6 +498,8 @@ class FlaxModelMixin:
         save_directory: Union[str, os.PathLike],
         params: Union[Dict, FrozenDict],
         is_main_process: bool = True,
+        push_to_hub: bool = False,
+        **kwargs,
     ):
         """
         Save a model and its configuration file to a directory so that it can be reloaded using the
@@ -511,6 +514,12 @@ class FlaxModelMixin:
                 Whether the process calling this is the main process or not. Useful during distributed training and you
                 need to call this function on all processes. In this case, set `is_main_process=True` only on the main
                 process to avoid race conditions.
+            push_to_hub (`bool`, *optional*, defaults to `False`):
+                Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
+                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+                namespace).
+            kwargs (`Dict[str, Any]`, *optional*):
+                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
         """
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
@@ -518,6 +527,14 @@ class FlaxModelMixin:
 
         os.makedirs(save_directory, exist_ok=True)
 
+        if push_to_hub:
+            commit_message = kwargs.pop("commit_message", None)
+            private = kwargs.pop("private", False)
+            create_pr = kwargs.pop("create_pr", False)
+            token = kwargs.pop("token", None)
+            repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+            repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
+
         model_to_save = self
 
         # Attach architecture to the config
@@ -532,3 +549,12 @@ class FlaxModelMixin:
             f.write(model_bytes)
 
         logger.info(f"Model weights saved in {output_model_file}")
+
+        if push_to_hub:
+            self._upload_folder(
+                save_directory,
+                repo_id,
+                token=token,
+                commit_message=commit_message,
+                create_pr=create_pr,
+            )
@@ -23,6 +23,7 @@ from typing import Any, Callable, List, Optional, Tuple, Union
 
 import safetensors
 import torch
+from huggingface_hub import create_repo
 from torch import Tensor, device, nn
 
 from .. import __version__
@@ -40,6 +41,7 @@ from ..utils import (
     is_torch_version,
     logging,
 )
+from ..utils.hub_utils import PushToHubMixin
 
 
 logger = logging.get_logger(__name__)
@@ -147,7 +149,7 @@ def _load_state_dict_into_model(model_to_load, state_dict):
     return error_msgs
 
 
-class ModelMixin(torch.nn.Module):
+class ModelMixin(torch.nn.Module, PushToHubMixin):
     r"""
     Base class for all models.
 
@@ -272,6 +274,8 @@ class ModelMixin(torch.nn.Module):
         save_function: Callable = None,
         safe_serialization: bool = False,
         variant: Optional[str] = None,
+        push_to_hub: bool = False,
+        **kwargs,
     ):
         """
         Save a model and its configuration file to a directory so that it can be reloaded using the
@@ -292,6 +296,12 @@ class ModelMixin(torch.nn.Module):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
             variant (`str`, *optional*):
                 If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
+            push_to_hub (`bool`, *optional*, defaults to `False`):
+                Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
+                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+                namespace).
+            kwargs (`Dict[str, Any]`, *optional*):
+                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
         """
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
@@ -299,6 +309,15 @@ class ModelMixin(torch.nn.Module):
 
         os.makedirs(save_directory, exist_ok=True)
 
+        if push_to_hub:
+            commit_message = kwargs.pop("commit_message", None)
+            private = kwargs.pop("private", False)
+            create_pr = kwargs.pop("create_pr", False)
+            token = kwargs.pop("token", None)
+            repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+            repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
+
+        # Only save the model itself if we are using distributed training
         model_to_save = self
 
         # Attach architecture to the config
@@ -322,6 +341,15 @@ class ModelMixin(torch.nn.Module):
 
         logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
 
+        if push_to_hub:
+            self._upload_folder(
+                save_directory,
+                repo_id,
+                token=token,
+                commit_message=commit_message,
+                create_pr=create_pr,
+            )
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
         r"""
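Both entry points added to `ModelMixin` are exercised below; this usage sketch follows the example in the `push_to_hub` docstring later in this diff (the target repo names are illustrative):

```python
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="unet")

# One-shot: serialize into a temporary directory and upload it.
unet.push_to_hub("my-finetuned-unet")

# Or keep a local copy and push the same folder in a single call.
unet.save_pretrained("my-finetuned-unet", push_to_hub=True)
```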
@@ -23,14 +23,22 @@ import flax
 import numpy as np
 import PIL
 from flax.core.frozen_dict import FrozenDict
-from huggingface_hub import snapshot_download
+from huggingface_hub import create_repo, snapshot_download
 from PIL import Image
 from tqdm.auto import tqdm
 
 from ..configuration_utils import ConfigMixin
 from ..models.modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin
 from ..schedulers.scheduling_utils_flax import SCHEDULER_CONFIG_NAME, FlaxSchedulerMixin
-from ..utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, http_user_agent, is_transformers_available, logging
+from ..utils import (
+    CONFIG_NAME,
+    DIFFUSERS_CACHE,
+    BaseOutput,
+    PushToHubMixin,
+    http_user_agent,
+    is_transformers_available,
+    logging,
+)
 
 
 if is_transformers_available():
@@ -90,7 +98,7 @@ class FlaxImagePipelineOutput(BaseOutput):
     images: Union[List[PIL.Image.Image], np.ndarray]
 
 
-class FlaxDiffusionPipeline(ConfigMixin):
+class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
     r"""
     Base class for Flax-based pipelines.
 
@@ -139,7 +147,13 @@ class FlaxDiffusionPipeline(ConfigMixin):
             # set models
             setattr(self, name, module)
 
-    def save_pretrained(self, save_directory: Union[str, os.PathLike], params: Union[Dict, FrozenDict]):
+    def save_pretrained(
+        self,
+        save_directory: Union[str, os.PathLike],
+        params: Union[Dict, FrozenDict],
+        push_to_hub: bool = False,
+        **kwargs,
+    ):
         # TODO: handle inference_state
         """
         Save all saveable variables of the pipeline to a directory. A pipeline variable can be saved and loaded if its
@@ -149,6 +163,12 @@ class FlaxDiffusionPipeline(ConfigMixin):
         Arguments:
             save_directory (`str` or `os.PathLike`):
                 Directory to which to save. Will be created if it doesn't exist.
+            push_to_hub (`bool`, *optional*, defaults to `False`):
+                Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
+                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+                namespace).
+            kwargs (`Dict[str, Any]`, *optional*):
+                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
         """
         self.save_config(save_directory)
 
@@ -157,6 +177,14 @@ class FlaxDiffusionPipeline(ConfigMixin):
         model_index_dict.pop("_diffusers_version")
         model_index_dict.pop("_module", None)
 
+        if push_to_hub:
+            commit_message = kwargs.pop("commit_message", None)
+            private = kwargs.pop("private", False)
+            create_pr = kwargs.pop("create_pr", False)
+            token = kwargs.pop("token", None)
+            repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+            repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
+
         for pipeline_component_name in model_index_dict.keys():
             sub_model = getattr(self, pipeline_component_name)
             if sub_model is None:
@@ -188,6 +216,15 @@ class FlaxDiffusionPipeline(ConfigMixin):
             else:
                 save_method(os.path.join(save_directory, pipeline_component_name))
 
+        if push_to_hub:
+            self._upload_folder(
+                save_directory,
+                repo_id,
+                token=token,
+                commit_message=commit_message,
+                create_pr=create_pr,
+            )
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
         r"""
@@ -28,7 +28,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
 import numpy as np
 import PIL
 import torch
-from huggingface_hub import ModelCard, hf_hub_download, model_info, snapshot_download
+from huggingface_hub import ModelCard, create_repo, hf_hub_download, model_info, snapshot_download
 from packaging import version
 from requests.exceptions import HTTPError
 from tqdm.auto import tqdm
@@ -66,7 +66,7 @@ if is_transformers_available():
     from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
     from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
 
-    from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME
+    from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, PushToHubMixin
 
 
 if is_accelerate_available():
@@ -472,7 +472,7 @@ def load_sub_model(
     return loaded_sub_model
 
 
-class DiffusionPipeline(ConfigMixin):
+class DiffusionPipeline(ConfigMixin, PushToHubMixin):
     r"""
     Base class for all pipelines.
 
@@ -558,6 +558,8 @@ class DiffusionPipeline(ConfigMixin):
         save_directory: Union[str, os.PathLike],
         safe_serialization: bool = False,
         variant: Optional[str] = None,
+        push_to_hub: bool = False,
+        **kwargs,
     ):
         """
         Save all saveable variables of the pipeline to a directory. A pipeline variable can be saved and loaded if its
@@ -571,6 +573,12 @@ class DiffusionPipeline(ConfigMixin):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
             variant (`str`, *optional*):
                 If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
+            push_to_hub (`bool`, *optional*, defaults to `False`):
+                Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
+                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+                namespace).
+            kwargs (`Dict[str, Any]`, *optional*):
+                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
         """
         model_index_dict = dict(self.config)
         model_index_dict.pop("_class_name", None)
@@ -578,6 +586,14 @@ class DiffusionPipeline(ConfigMixin):
         model_index_dict.pop("_module", None)
         model_index_dict.pop("_name_or_path", None)
 
+        if push_to_hub:
+            commit_message = kwargs.pop("commit_message", None)
+            private = kwargs.pop("private", False)
+            create_pr = kwargs.pop("create_pr", False)
+            token = kwargs.pop("token", None)
+            repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+            repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
+
         expected_modules, optional_kwargs = self._get_signature_keys(self)
 
         def is_saveable_module(name, value):
@@ -641,6 +657,15 @@ class DiffusionPipeline(ConfigMixin):
         # finally save the config
         self.save_config(save_directory)
 
+        if push_to_hub:
+            self._upload_folder(
+                save_directory,
+                repo_id,
+                token=token,
+                commit_message=commit_message,
+                create_pr=create_pr,
+            )
+
     def to(
         self,
         torch_device: Optional[Union[str, torch.device]] = None,
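The same pattern now works one level up: `DiffusionPipeline.save_pretrained` saves each component into its own subfolder and, with `push_to_hub=True`, uploads the whole tree in one commit. A usage sketch (model and repo names illustrative):

```python
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Each component (unet, vae, text_encoder, ...) lands in its own subfolder
# of the uploaded repo, mirroring the on-disk layout of save_pretrained.
pipeline.push_to_hub("my-pipeline", private=True)
```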
@@ -19,7 +19,7 @@ from typing import Any, Dict, Optional, Union
 
 import torch
 
-from ..utils import BaseOutput
+from ..utils import BaseOutput, PushToHubMixin
 
 
 SCHEDULER_CONFIG_NAME = "scheduler_config.json"
@@ -60,7 +60,7 @@ class SchedulerOutput(BaseOutput):
     prev_sample: torch.FloatTensor
 
 
-class SchedulerMixin:
+class SchedulerMixin(PushToHubMixin):
     """
     Base class for all schedulers.
 
@@ -153,7 +153,13 @@ class SchedulerMixin:
 
         Args:
             save_directory (`str` or `os.PathLike`):
-                Directory to save a configuration JSON file to. Will be created if it doesn't exist.
+                Directory where the configuration JSON file will be saved (will be created if it does not exist).
+            push_to_hub (`bool`, *optional*, defaults to `False`):
+                Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
+                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+                namespace).
+            kwargs (`Dict[str, Any]`, *optional*):
+                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
         """
         self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
 
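Since schedulers carry no weights, `save_pretrained` simply delegates to `save_config`, and a push uploads a single `scheduler_config.json`. A sketch mirroring the tests at the end of this diff (the repo name is illustrative):

```python
from diffusers import DDIMScheduler

scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
scheduler.push_to_hub("my-ddim-scheduler")  # uploads scheduler_config.json only
```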
@@ -21,7 +21,7 @@ from typing import Any, Dict, Optional, Tuple, Union
 import flax
 import jax.numpy as jnp
 
-from ..utils import BaseOutput
+from ..utils import BaseOutput, PushToHubMixin
 
 
 SCHEDULER_CONFIG_NAME = "scheduler_config.json"
@@ -53,7 +53,7 @@ class FlaxSchedulerOutput(BaseOutput):
     prev_sample: jnp.ndarray
 
 
-class FlaxSchedulerMixin:
+class FlaxSchedulerMixin(PushToHubMixin):
     """
     Mixin containing common functions for the schedulers.
 
@@ -156,6 +156,12 @@ class FlaxSchedulerMixin:
         Args:
             save_directory (`str` or `os.PathLike`):
                 Directory where the configuration JSON file will be saved (will be created if it does not exist).
+            push_to_hub (`bool`, *optional*, defaults to `False`):
+                Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
+                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+                namespace).
+            kwargs (`Dict[str, Any]`, *optional*):
+                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
         """
         self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
 
@@ -37,6 +37,7 @@ from .doc_utils import replace_example_docstring
 from .dynamic_modules_utils import get_class_from_dynamic_module
 from .hub_utils import (
     HF_HUB_OFFLINE,
+    PushToHubMixin,
     _add_variant,
     _get_model_file,
     extract_commit_hash,
@@ -17,13 +17,22 @@
 import os
 import re
 import sys
+import tempfile
 import traceback
 import warnings
 from pathlib import Path
 from typing import Dict, Optional, Union
 from uuid import uuid4
 
-from huggingface_hub import HfFolder, ModelCard, ModelCardData, hf_hub_download, whoami
+from huggingface_hub import (
+    HfFolder,
+    ModelCard,
+    ModelCardData,
+    create_repo,
+    hf_hub_download,
+    upload_folder,
+    whoami,
+)
 from huggingface_hub.file_download import REGEX_COMMIT_HASH
 from huggingface_hub.utils import (
     EntryNotFoundError,
@@ -359,3 +368,96 @@ def _get_model_file(
                 f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
                 f"containing a file named {weights_name}"
             )
+
+
+class PushToHubMixin:
+    """
+    A Mixin containing the functionality to push a model/scheduler to the Hugging Face Hub.
+    """
+
+    def _upload_folder(
+        self,
+        working_dir: Union[str, os.PathLike],
+        repo_id: str,
+        token: Optional[str] = None,
+        commit_message: Optional[str] = None,
+        create_pr: bool = False,
+    ):
+        """
+        Uploads all files in `working_dir` to `repo_id`.
+        """
+        if commit_message is None:
+            if "Model" in self.__class__.__name__:
+                commit_message = "Upload model"
+            elif "Scheduler" in self.__class__.__name__:
+                commit_message = "Upload scheduler"
+            else:
+                commit_message = f"Upload {self.__class__.__name__}"
+
+        logger.info(f"Uploading the files of {working_dir} to {repo_id}.")
+        return upload_folder(
+            repo_id=repo_id, folder_path=working_dir, token=token, commit_message=commit_message, create_pr=create_pr
+        )
+
+    def push_to_hub(
+        self,
+        repo_id: str,
+        commit_message: Optional[str] = None,
+        private: Optional[bool] = None,
+        token: Optional[str] = None,
+        create_pr: bool = False,
+        safe_serialization: bool = True,
+        variant: Optional[str] = None,
+    ) -> str:
+        """
+        Upload the {object_files} to the 🤗 Hugging Face Hub.
+
+        Parameters:
+            repo_id (`str`):
+                The name of the repository you want to push your {object} to. It should contain your organization name
+                when pushing to a given organization. `repo_id` can also be a path to a local directory.
+            commit_message (`str`, *optional*):
+                Message to commit while pushing. Will default to `"Upload {object}"`.
+            private (`bool`, *optional*):
+                Whether or not the repository created should be private.
+            token (`str`, *optional*):
+                The token to use as HTTP bearer authorization for remote files. Defaults to the token generated when
+                running `huggingface-cli login` (stored in `~/.huggingface`).
+            create_pr (`bool`, *optional*, defaults to `False`):
+                Whether or not to create a PR with the uploaded files or directly commit.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether or not to convert the model weights to the `safetensors` format for safer serialization.
+            variant (`str`, *optional*):
+                If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
+
+        Examples:
+
+        ```python
+        from diffusers import UNet2DConditionModel
+
+        unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="unet")
+
+        # Push the `unet` to your namespace with the name "my-finetuned-unet".
+        unet.push_to_hub("my-finetuned-unet")
+
+        # Push the `unet` to an organization with the name "my-finetuned-unet".
+        unet.push_to_hub("your-org/my-finetuned-unet")
+        ```
+        """
+        repo_id = create_repo(repo_id, private=private, token=token, exist_ok=True).repo_id
+
+        # Save all files.
+        save_kwargs = {"safe_serialization": safe_serialization}
+        if "Scheduler" not in self.__class__.__name__:
+            save_kwargs.update({"variant": variant})
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            self.save_pretrained(tmpdir, **save_kwargs)
+
+            return self._upload_folder(
+                tmpdir,
+                repo_id,
+                token=token,
+                commit_message=commit_message,
+                create_pr=create_pr,
+            )
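For readers who want the mechanics without the mixin, here is a standalone sketch of what `push_to_hub` boils down to, using only `huggingface_hub` primitives (the function name `push_folder` and its signature are illustrative, not part of the PR):

```python
import tempfile
from typing import Optional

from huggingface_hub import create_repo, upload_folder


def push_folder(obj, repo_id: str, token: Optional[str] = None, private: bool = False) -> str:
    # Ensure the target repo exists and resolve the fully qualified repo id.
    repo_id = create_repo(repo_id, private=private, token=token, exist_ok=True).repo_id
    with tempfile.TemporaryDirectory() as tmpdir:
        obj.save_pretrained(tmpdir)  # serialize weights/config locally
        # Upload every file in the folder as a single commit.
        return upload_folder(
            repo_id=repo_id,
            folder_path=tmpdir,
            token=token,
            commit_message=f"Upload {obj.__class__.__name__}",
        )
```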
@@ -18,18 +18,27 @@ import tempfile
 import traceback
 import unittest
 import unittest.mock as mock
+import uuid
 from typing import Dict, List, Tuple
 
 import numpy as np
 import requests_mock
 import torch
+from huggingface_hub import delete_repo
 from requests.exceptions import HTTPError
 
 from diffusers.models import UNet2DConditionModel
 from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0, XFormersAttnProcessor
 from diffusers.training_utils import EMAModel
 from diffusers.utils import logging, torch_device
-from diffusers.utils.testing_utils import CaptureLogger, require_torch_2, require_torch_gpu, run_test_in_subprocess
+from diffusers.utils.testing_utils import (
+    CaptureLogger,
+    require_torch_2,
+    require_torch_gpu,
+    run_test_in_subprocess,
+)
+
+from ..others.test_utils import TOKEN, USER, is_staging_test
 
 
 # Will be run via run_test_in_subprocess
@@ -563,3 +572,72 @@ class ModelTesterMixin:
                 f" {self.model_class}.__init__ if there are deprecated arguments or remove the deprecated argument"
                 " from `_deprecated_kwargs = [<deprecated_argument>]`"
             )
+
+
+@is_staging_test
+class ModelPushToHubTester(unittest.TestCase):
+    identifier = uuid.uuid4()
+    repo_id = f"test-model-{identifier}"
+    org_repo_id = f"valid_org/{repo_id}-org"
+
+    def test_push_to_hub(self):
+        model = UNet2DConditionModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            sample_size=32,
+            in_channels=4,
+            out_channels=4,
+            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+            cross_attention_dim=32,
+        )
+        model.push_to_hub(self.repo_id, token=TOKEN)
+
+        new_model = UNet2DConditionModel.from_pretrained(f"{USER}/{self.repo_id}")
+        for p1, p2 in zip(model.parameters(), new_model.parameters()):
+            self.assertTrue(torch.equal(p1, p2))
+
+        # Reset repo
+        delete_repo(token=TOKEN, repo_id=self.repo_id)
+
+        # Push to hub via save_pretrained
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir, repo_id=self.repo_id, push_to_hub=True, token=TOKEN)
+
+        new_model = UNet2DConditionModel.from_pretrained(f"{USER}/{self.repo_id}")
+        for p1, p2 in zip(model.parameters(), new_model.parameters()):
+            self.assertTrue(torch.equal(p1, p2))
+
+        # Reset repo
+        delete_repo(self.repo_id, token=TOKEN)
+
+    def test_push_to_hub_in_organization(self):
+        model = UNet2DConditionModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            sample_size=32,
+            in_channels=4,
+            out_channels=4,
+            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+            cross_attention_dim=32,
+        )
+        model.push_to_hub(self.org_repo_id, token=TOKEN)
+
+        new_model = UNet2DConditionModel.from_pretrained(self.org_repo_id)
+        for p1, p2 in zip(model.parameters(), new_model.parameters()):
+            self.assertTrue(torch.equal(p1, p2))
+
+        # Reset repo
+        delete_repo(token=TOKEN, repo_id=self.org_repo_id)
+
+        # Push to hub via save_pretrained
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir, push_to_hub=True, token=TOKEN, repo_id=self.org_repo_id)
+
+        new_model = UNet2DConditionModel.from_pretrained(self.org_repo_id)
+        for p1, p2 in zip(model.parameters(), new_model.parameters()):
+            self.assertTrue(torch.equal(p1, p2))
+
+        # Reset repo
+        delete_repo(self.org_repo_id, token=TOKEN)
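One caveat with the tests above: the `delete_repo` calls run only if every assertion before them passes, so a failing run can leave repos behind on the staging Hub. A hedged variant (not part of the PR) that guarantees cleanup via `unittest`'s `addCleanup`:

```python
import unittest

from huggingface_hub import delete_repo

from diffusers import DDIMScheduler


class CleanupExample(unittest.TestCase):
    TOKEN = "hf_..."  # placeholder staging token

    def test_push_roundtrip(self):
        scheduler = DDIMScheduler()
        scheduler.push_to_hub("test-scheduler-example", token=self.TOKEN)
        # Registered right after the push, so the repo is removed even if an
        # assertion below fails.
        self.addCleanup(delete_repo, repo_id="test-scheduler-example", token=self.TOKEN)

        restored = DDIMScheduler.from_pretrained("__DUMMY_TRANSFORMERS_USER__/test-scheduler-example")
        self.assertIs(type(restored), DDIMScheduler)
```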
@@ -13,12 +13,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 import unittest
+from distutils.util import strtobool
+
+import pytest
 
 from diffusers import __version__
 from diffusers.utils import deprecate
 
 
+# Used to test the hub
+USER = "__DUMMY_TRANSFORMERS_USER__"
+ENDPOINT_STAGING = "https://hub-ci.huggingface.co"
+
+# Not critical, only usable on the sandboxed CI instance.
+TOKEN = "hf_94wBhPGp6KrrTH3KDchhKpRxZwd6dmHWLL"
+
+
 class DeprecateTester(unittest.TestCase):
     higher_version = ".".join([str(int(__version__.split(".")[0]) + 1)] + __version__.split(".")[1:])
     lower_version = "0.0.1"
@@ -168,3 +180,34 @@ class DeprecateTester(unittest.TestCase):
             deprecate(("deprecated_arg", self.higher_version, "This message is better!!!"), standard_warn=False)
         assert str(warning.warning) == "This message is better!!!"
         assert "diffusers/tests/others/test_utils.py" in warning.filename
+
+
+def parse_flag_from_env(key, default=False):
+    try:
+        value = os.environ[key]
+    except KeyError:
+        # KEY isn't set, default to `default`.
+        _value = default
+    else:
+        # KEY is set, convert it to True or False.
+        try:
+            _value = strtobool(value)
+        except ValueError:
+            # More values are supported, but let's keep the message simple.
+            raise ValueError(f"If set, {key} must be yes or no.")
+    return _value
+
+
+_run_staging = parse_flag_from_env("HUGGINGFACE_CO_STAGING", default=False)
+
+
+def is_staging_test(test_case):
+    """
+    Decorator marking a test as a staging test.
+
+    Those tests will run using the staging environment of huggingface.co instead of the real model hub.
+    """
+    if not _run_staging:
+        return unittest.skip("test is staging test")(test_case)
+    else:
+        return pytest.mark.is_staging_test()(test_case)
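A short sketch of how a test opts in (class and assertions are illustrative). Because `_run_staging` is evaluated when this module is imported, `HUGGINGFACE_CO_STAGING=true` must already be in the environment at import time:

```python
import unittest

# assuming the helpers above are importable the way the PR's tests import them
from tests.others.test_utils import TOKEN, USER, is_staging_test


@is_staging_test  # skipped entirely unless HUGGINGFACE_CO_STAGING=true
class ExampleHubTester(unittest.TestCase):
    def test_constants_available(self):
        self.assertTrue(TOKEN.startswith("hf_"))
        self.assertTrue(USER.startswith("__DUMMY"))
```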
@@ -2,23 +2,30 @@ import contextlib
 import gc
 import inspect
 import io
 import json
 import os
 import re
 import tempfile
 import unittest
+import uuid
 from typing import Callable, Union
 
 import numpy as np
 import PIL
 import torch
+from huggingface_hub import delete_repo
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 
 import diffusers
-from diffusers import DiffusionPipeline
+from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.utils import logging
 from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version, is_xformers_available
 from diffusers.utils.testing_utils import CaptureLogger, require_torch, torch_device
 
+from ..others.test_utils import TOKEN, USER, is_staging_test
+
 
 def to_np(tensor):
     if isinstance(tensor, torch.Tensor):
@@ -795,6 +802,126 @@ class PipelineTesterMixin:
         assert out_cfg.shape == out_no_cfg.shape
 
 
+@is_staging_test
+class PipelinePushToHubTester(unittest.TestCase):
+    identifier = uuid.uuid4()
+    repo_id = f"test-pipeline-{identifier}"
+    org_repo_id = f"valid_org/{repo_id}-org"
+
+    def get_pipeline_components(self):
+        unet = UNet2DConditionModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            sample_size=32,
+            in_channels=4,
+            out_channels=4,
+            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+            cross_attention_dim=32,
+        )
+
+        scheduler = DDIMScheduler(
+            beta_start=0.00085,
+            beta_end=0.012,
+            beta_schedule="scaled_linear",
+            clip_sample=False,
+            set_alpha_to_one=False,
+        )
+
+        vae = AutoencoderKL(
+            block_out_channels=[32, 64],
+            in_channels=3,
+            out_channels=3,
+            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+            latent_channels=4,
+        )
+
+        text_encoder_config = CLIPTextConfig(
+            bos_token_id=0,
+            eos_token_id=2,
+            hidden_size=32,
+            intermediate_size=37,
+            layer_norm_eps=1e-05,
+            num_attention_heads=4,
+            num_hidden_layers=5,
+            pad_token_id=1,
+            vocab_size=1000,
+        )
+        text_encoder = CLIPTextModel(text_encoder_config)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            dummy_vocab = {"<|startoftext|>": 0, "<|endoftext|>": 1, "!": 2}
+            vocab_path = os.path.join(tmpdir, "vocab.json")
+            with open(vocab_path, "w") as f:
+                json.dump(dummy_vocab, f)
+
+            merges = "Ġ t\nĠt h"
+            merges_path = os.path.join(tmpdir, "merges.txt")
+            with open(merges_path, "w") as f:
+                f.writelines(merges)
+            tokenizer = CLIPTokenizer(vocab_file=vocab_path, merges_file=merges_path)
+
+        components = {
+            "unet": unet,
+            "scheduler": scheduler,
+            "vae": vae,
+            "text_encoder": text_encoder,
+            "tokenizer": tokenizer,
+            "safety_checker": None,
+            "feature_extractor": None,
+        }
+        return components
+
+    def test_push_to_hub(self):
+        components = self.get_pipeline_components()
+        pipeline = StableDiffusionPipeline(**components)
+        pipeline.push_to_hub(self.repo_id, token=TOKEN)
+
+        new_model = UNet2DConditionModel.from_pretrained(f"{USER}/{self.repo_id}", subfolder="unet")
+        unet = components["unet"]
+        for p1, p2 in zip(unet.parameters(), new_model.parameters()):
+            self.assertTrue(torch.equal(p1, p2))
+
+        # Reset repo
+        delete_repo(token=TOKEN, repo_id=self.repo_id)
+
+        # Push to hub via save_pretrained
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            pipeline.save_pretrained(tmp_dir, repo_id=self.repo_id, push_to_hub=True, token=TOKEN)
+
+        new_model = UNet2DConditionModel.from_pretrained(f"{USER}/{self.repo_id}", subfolder="unet")
+        for p1, p2 in zip(unet.parameters(), new_model.parameters()):
+            self.assertTrue(torch.equal(p1, p2))
+
+        # Reset repo
+        delete_repo(self.repo_id, token=TOKEN)
+
+    def test_push_to_hub_in_organization(self):
+        components = self.get_pipeline_components()
+        pipeline = StableDiffusionPipeline(**components)
+        pipeline.push_to_hub(self.org_repo_id, token=TOKEN)
+
+        new_model = UNet2DConditionModel.from_pretrained(self.org_repo_id, subfolder="unet")
+        unet = components["unet"]
+        for p1, p2 in zip(unet.parameters(), new_model.parameters()):
+            self.assertTrue(torch.equal(p1, p2))
+
+        # Reset repo
+        delete_repo(token=TOKEN, repo_id=self.org_repo_id)
+
+        # Push to hub via save_pretrained
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            pipeline.save_pretrained(tmp_dir, push_to_hub=True, token=TOKEN, repo_id=self.org_repo_id)
+
+        new_model = UNet2DConditionModel.from_pretrained(self.org_repo_id, subfolder="unet")
+        for p1, p2 in zip(unet.parameters(), new_model.parameters()):
+            self.assertTrue(torch.equal(p1, p2))
+
+        # Reset repo
+        delete_repo(self.org_repo_id, token=TOKEN)
+
+
 # Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
 # This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
 # reference image.
@@ -17,10 +17,12 @@ import json
 import os
 import tempfile
 import unittest
+import uuid
 from typing import Dict, List, Tuple
 
 import numpy as np
 import torch
+from huggingface_hub import delete_repo
 
 import diffusers
 from diffusers import (
@@ -41,6 +43,8 @@ from diffusers.schedulers.scheduling_utils import SchedulerMixin
 from diffusers.utils import torch_device
 from diffusers.utils.testing_utils import CaptureLogger
 
+from ..others.test_utils import TOKEN, USER, is_staging_test
+
 
 torch.backends.cuda.matmul.allow_tf32 = False
 
@@ -720,3 +724,64 @@ class SchedulerCommonTest(unittest.TestCase):
             scheduler.does_not_exist
 
         assert str(error.exception) == f"'{type(scheduler).__name__}' object has no attribute 'does_not_exist'"
+
+
+@is_staging_test
+class SchedulerPushToHubTester(unittest.TestCase):
+    identifier = uuid.uuid4()
+    repo_id = f"test-scheduler-{identifier}"
+    org_repo_id = f"valid_org/{repo_id}-org"
+
+    def test_push_to_hub(self):
+        scheduler = DDIMScheduler(
+            beta_start=0.00085,
+            beta_end=0.012,
+            beta_schedule="scaled_linear",
+            clip_sample=False,
+            set_alpha_to_one=False,
+        )
+        scheduler.push_to_hub(self.repo_id, token=TOKEN)
+        scheduler_loaded = DDIMScheduler.from_pretrained(f"{USER}/{self.repo_id}")
+
+        assert type(scheduler) == type(scheduler_loaded)
+
+        # Reset repo
+        delete_repo(token=TOKEN, repo_id=self.repo_id)
+
+        # Push to hub via save_config
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            scheduler.save_config(tmp_dir, repo_id=self.repo_id, push_to_hub=True, token=TOKEN)
+
+        scheduler_loaded = DDIMScheduler.from_pretrained(f"{USER}/{self.repo_id}")
+
+        assert type(scheduler) == type(scheduler_loaded)
+
+        # Reset repo
+        delete_repo(token=TOKEN, repo_id=self.repo_id)
+
+    def test_push_to_hub_in_organization(self):
+        scheduler = DDIMScheduler(
+            beta_start=0.00085,
+            beta_end=0.012,
+            beta_schedule="scaled_linear",
+            clip_sample=False,
+            set_alpha_to_one=False,
+        )
+        scheduler.push_to_hub(self.org_repo_id, token=TOKEN)
+        scheduler_loaded = DDIMScheduler.from_pretrained(self.org_repo_id)
+
+        assert type(scheduler) == type(scheduler_loaded)
+
+        # Reset repo
+        delete_repo(token=TOKEN, repo_id=self.org_repo_id)
+
+        # Push to hub via save_config
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            scheduler.save_config(tmp_dir, repo_id=self.org_repo_id, push_to_hub=True, token=TOKEN)
+
+        scheduler_loaded = DDIMScheduler.from_pretrained(self.org_repo_id)
+
+        assert type(scheduler) == type(scheduler_loaded)
+
+        # Reset repo
+        delete_repo(token=TOKEN, repo_id=self.org_repo_id)