diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml
index 1c52428310..defd418edc 100644
--- a/.github/workflows/pr_tests.yml
+++ b/.github/workflows/pr_tests.yml
@@ -113,3 +113,60 @@ jobs:
       with:
         name: pr_${{ matrix.config.report }}_test_reports
         path: reports
+
+  run_staging_tests:
+    strategy:
+      fail-fast: false
+      matrix:
+        config:
+          - name: Hub tests for models, schedulers, and pipelines
+            framework: hub_tests_pytorch
+            runner: docker-cpu
+            image: diffusers/diffusers-pytorch-cpu
+            report: torch_hub
+
+    name: ${{ matrix.config.name }}
+
+    runs-on: ${{ matrix.config.runner }}
+
+    container:
+      image: ${{ matrix.config.image }}
+      options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
+
+    defaults:
+      run:
+        shell: bash
+
+    steps:
+    - name: Checkout diffusers
+      uses: actions/checkout@v3
+      with:
+        fetch-depth: 2
+
+    - name: Install dependencies
+      run: |
+        apt-get update && apt-get install libsndfile1-dev libgl1 -y
+        python -m pip install -e .[quality,test]
+
+    - name: Environment
+      run: |
+        python utils/print_env.py
+
+    - name: Run Hub tests for models, schedulers, and pipelines on a staging env
+      if: ${{ matrix.config.framework == 'hub_tests_pytorch' }}
+      run: |
+        HUGGINGFACE_CO_STAGING=true python -m pytest \
+          -m "is_staging_test" \
+          --make-reports=tests_${{ matrix.config.report }} \
+          tests
+
+    - name: Failure short reports
+      if: ${{ failure() }}
+      run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
+
+    - name: Test suite reports artifacts
+      if: ${{ always() }}
+      uses: actions/upload-artifact@v2
+      with:
+        name: pr_${{ matrix.config.report }}_test_reports
+        path: reports
\ No newline at end of file
diff --git a/docs/source/en/api/models/overview.md b/docs/source/en/api/models/overview.md
index cc94861fba..b4e2d338e9 100644
--- a/docs/source/en/api/models/overview.md
+++ b/docs/source/en/api/models/overview.md
@@ -9,4 +9,8 @@ All models are built from the base [`ModelMixin`] class which is a [`torch.nn.mo
 
 ## FlaxModelMixin
 
-[[autodoc]] FlaxModelMixin
\ No newline at end of file
+[[autodoc]] FlaxModelMixin
+
+## Pushing to the Hub
+
+[[autodoc]] utils.PushToHubMixin
\ No newline at end of file
diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py
index f5c8e8919c..9bc25155a0 100644
--- a/src/diffusers/configuration_utils.py
+++ b/src/diffusers/configuration_utils.py
@@ -26,7 +26,7 @@ from pathlib import PosixPath
 from typing import Any, Dict, Tuple, Union
 
 import numpy as np
-from huggingface_hub import hf_hub_download
+from huggingface_hub import create_repo, hf_hub_download
 from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
 from requests import HTTPError
 
@@ -144,6 +144,12 @@ class ConfigMixin:
         Args:
             save_directory (`str` or `os.PathLike`):
                 Directory where the configuration JSON file is saved (will be created if it does not exist).
+            push_to_hub (`bool`, *optional*, defaults to `False`):
+                Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
+                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+                namespace).
+            kwargs (`Dict[str, Any]`, *optional*):
+                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
         """
         if os.path.isfile(save_directory):
             raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
@@ -156,6 +162,22 @@ class ConfigMixin:
         self.to_json_file(output_config_file)
         logger.info(f"Configuration saved in {output_config_file}")
 
+        if push_to_hub:
+            commit_message = kwargs.pop("commit_message", None)
+            private = kwargs.pop("private", False)
+            create_pr = kwargs.pop("create_pr", False)
+            token = kwargs.pop("token", None)
+            repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+            repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
+
+            self._upload_folder(
+                save_directory,
+                repo_id,
+                token=token,
+                commit_message=commit_message,
+                create_pr=create_pr,
+            )
+
     @classmethod
     def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
         r"""
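A minimal sketch of the intended `save_config` flow (assuming the updated signature accepts `push_to_hub` and `**kwargs`, as the body above implies; the repo id is hypothetical and a `huggingface-cli login` token, or an explicit `token=`, is assumed):

```python
from diffusers import DDIMScheduler

scheduler = DDIMScheduler()

# `commit_message`, `private`, `create_pr`, `token`, and `repo_id` are popped from
# **kwargs before the saved folder is uploaded; `repo_id` defaults to the last path
# component of `save_directory` under your namespace.
scheduler.save_config(
    "ddim-config",                  # local save directory
    push_to_hub=True,
    repo_id="my-user/ddim-config",  # hypothetical repository id
    private=True,
    commit_message="Add scheduler config",
)
```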
""" if os.path.isfile(save_directory): raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") @@ -156,6 +162,22 @@ class ConfigMixin: self.to_json_file(output_config_file) logger.info(f"Configuration saved in {output_config_file}") + if push_to_hub: + commit_message = kwargs.pop("commit_message", None) + private = kwargs.pop("private", False) + create_pr = kwargs.pop("create_pr", False) + token = kwargs.pop("token", None) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id + + self._upload_folder( + save_directory, + repo_id, + token=token, + commit_message=commit_message, + create_pr=create_pr, + ) + @classmethod def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs): r""" diff --git a/src/diffusers/models/modeling_flax_utils.py b/src/diffusers/models/modeling_flax_utils.py index 9a6e1b3bba..4e4cebebe2 100644 --- a/src/diffusers/models/modeling_flax_utils.py +++ b/src/diffusers/models/modeling_flax_utils.py @@ -23,7 +23,7 @@ import msgpack.exceptions from flax.core.frozen_dict import FrozenDict, unfreeze from flax.serialization import from_bytes, to_bytes from flax.traverse_util import flatten_dict, unflatten_dict -from huggingface_hub import hf_hub_download +from huggingface_hub import create_repo, hf_hub_download from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError from requests import HTTPError @@ -34,6 +34,7 @@ from ..utils import ( FLAX_WEIGHTS_NAME, HUGGINGFACE_CO_RESOLVE_ENDPOINT, WEIGHTS_NAME, + PushToHubMixin, logging, ) from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax @@ -42,7 +43,7 @@ from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax logger = logging.get_logger(__name__) -class FlaxModelMixin: +class FlaxModelMixin(PushToHubMixin): r""" Base class for all Flax models. @@ -497,6 +498,8 @@ class FlaxModelMixin: save_directory: Union[str, os.PathLike], params: Union[Dict, FrozenDict], is_main_process: bool = True, + push_to_hub: bool = False, + **kwargs, ): """ Save a model and its configuration file to a directory so that it can be reloaded using the @@ -511,6 +514,12 @@ class FlaxModelMixin: Whether the process calling this is the main process or not. Useful during distributed training and you need to call this function on all processes. In this case, set `is_main_process=True` only on the main process to avoid race conditions. + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + kwargs (`Dict[str, Any]`, *optional*): + Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. 
""" if os.path.isfile(save_directory): logger.error(f"Provided path ({save_directory}) should be a directory, not a file") @@ -518,6 +527,14 @@ class FlaxModelMixin: os.makedirs(save_directory, exist_ok=True) + if push_to_hub: + commit_message = kwargs.pop("commit_message", None) + private = kwargs.pop("private", False) + create_pr = kwargs.pop("create_pr", False) + token = kwargs.pop("token", None) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id + model_to_save = self # Attach architecture to the config @@ -532,3 +549,12 @@ class FlaxModelMixin: f.write(model_bytes) logger.info(f"Model weights saved in {output_model_file}") + + if push_to_hub: + self._upload_folder( + save_directory, + repo_id, + token=token, + commit_message=commit_message, + create_pr=create_pr, + ) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index e1bb1a94ba..b575c9cdb2 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -23,6 +23,7 @@ from typing import Any, Callable, List, Optional, Tuple, Union import safetensors import torch +from huggingface_hub import create_repo from torch import Tensor, device, nn from .. import __version__ @@ -40,6 +41,7 @@ from ..utils import ( is_torch_version, logging, ) +from ..utils.hub_utils import PushToHubMixin logger = logging.get_logger(__name__) @@ -147,7 +149,7 @@ def _load_state_dict_into_model(model_to_load, state_dict): return error_msgs -class ModelMixin(torch.nn.Module): +class ModelMixin(torch.nn.Module, PushToHubMixin): r""" Base class for all models. @@ -272,6 +274,8 @@ class ModelMixin(torch.nn.Module): save_function: Callable = None, safe_serialization: bool = False, variant: Optional[str] = None, + push_to_hub: bool = False, + **kwargs, ): """ Save a model and its configuration file to a directory so that it can be reloaded using the @@ -292,6 +296,12 @@ class ModelMixin(torch.nn.Module): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. variant (`str`, *optional*): If specified, weights are saved in the format `pytorch_model..bin`. + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. 
""" if os.path.isfile(save_directory): logger.error(f"Provided path ({save_directory}) should be a directory, not a file") @@ -299,6 +309,15 @@ class ModelMixin(torch.nn.Module): os.makedirs(save_directory, exist_ok=True) + if push_to_hub: + commit_message = kwargs.pop("commit_message", None) + private = kwargs.pop("private", False) + create_pr = kwargs.pop("create_pr", False) + token = kwargs.pop("token", None) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id + + # Only save the model itself if we are using distributed training model_to_save = self # Attach architecture to the config @@ -322,6 +341,15 @@ class ModelMixin(torch.nn.Module): logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}") + if push_to_hub: + self._upload_folder( + save_directory, + repo_id, + token=token, + commit_message=commit_message, + create_pr=create_pr, + ) + @classmethod def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): r""" diff --git a/src/diffusers/pipelines/pipeline_flax_utils.py b/src/diffusers/pipelines/pipeline_flax_utils.py index 21fbc36c61..23a7af1e1b 100644 --- a/src/diffusers/pipelines/pipeline_flax_utils.py +++ b/src/diffusers/pipelines/pipeline_flax_utils.py @@ -23,14 +23,22 @@ import flax import numpy as np import PIL from flax.core.frozen_dict import FrozenDict -from huggingface_hub import snapshot_download +from huggingface_hub import create_repo, snapshot_download from PIL import Image from tqdm.auto import tqdm from ..configuration_utils import ConfigMixin from ..models.modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin from ..schedulers.scheduling_utils_flax import SCHEDULER_CONFIG_NAME, FlaxSchedulerMixin -from ..utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, http_user_agent, is_transformers_available, logging +from ..utils import ( + CONFIG_NAME, + DIFFUSERS_CACHE, + BaseOutput, + PushToHubMixin, + http_user_agent, + is_transformers_available, + logging, +) if is_transformers_available(): @@ -90,7 +98,7 @@ class FlaxImagePipelineOutput(BaseOutput): images: Union[List[PIL.Image.Image], np.ndarray] -class FlaxDiffusionPipeline(ConfigMixin): +class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin): r""" Base class for Flax-based pipelines. @@ -139,7 +147,13 @@ class FlaxDiffusionPipeline(ConfigMixin): # set models setattr(self, name, module) - def save_pretrained(self, save_directory: Union[str, os.PathLike], params: Union[Dict, FrozenDict]): + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + params: Union[Dict, FrozenDict], + push_to_hub: bool = False, + **kwargs, + ): # TODO: handle inference_state """ Save all saveable variables of the pipeline to a directory. A pipeline variable can be saved and loaded if its @@ -149,6 +163,12 @@ class FlaxDiffusionPipeline(ConfigMixin): Arguments: save_directory (`str` or `os.PathLike`): Directory to which to save. Will be created if it doesn't exist. + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. 
""" self.save_config(save_directory) @@ -157,6 +177,14 @@ class FlaxDiffusionPipeline(ConfigMixin): model_index_dict.pop("_diffusers_version") model_index_dict.pop("_module", None) + if push_to_hub: + commit_message = kwargs.pop("commit_message", None) + private = kwargs.pop("private", False) + create_pr = kwargs.pop("create_pr", False) + token = kwargs.pop("token", None) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id + for pipeline_component_name in model_index_dict.keys(): sub_model = getattr(self, pipeline_component_name) if sub_model is None: @@ -188,6 +216,15 @@ class FlaxDiffusionPipeline(ConfigMixin): else: save_method(os.path.join(save_directory, pipeline_component_name)) + if push_to_hub: + self._upload_folder( + save_directory, + repo_id, + token=token, + commit_message=commit_message, + create_pr=create_pr, + ) + @classmethod def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): r""" diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index fdcb7029cf..75cc0eae8c 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -28,7 +28,7 @@ from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import PIL import torch -from huggingface_hub import ModelCard, hf_hub_download, model_info, snapshot_download +from huggingface_hub import ModelCard, create_repo, hf_hub_download, model_info, snapshot_download from packaging import version from requests.exceptions import HTTPError from tqdm.auto import tqdm @@ -66,7 +66,7 @@ if is_transformers_available(): from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME -from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME +from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, PushToHubMixin if is_accelerate_available(): @@ -472,7 +472,7 @@ def load_sub_model( return loaded_sub_model -class DiffusionPipeline(ConfigMixin): +class DiffusionPipeline(ConfigMixin, PushToHubMixin): r""" Base class for all pipelines. @@ -558,6 +558,8 @@ class DiffusionPipeline(ConfigMixin): save_directory: Union[str, os.PathLike], safe_serialization: bool = False, variant: Optional[str] = None, + push_to_hub: bool = False, + **kwargs, ): """ Save all saveable variables of the pipeline to a directory. A pipeline variable can be saved and loaded if its @@ -571,6 +573,12 @@ class DiffusionPipeline(ConfigMixin): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. variant (`str`, *optional*): If specified, weights are saved in the format `pytorch_model..bin`. + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. 
""" model_index_dict = dict(self.config) model_index_dict.pop("_class_name", None) @@ -578,6 +586,14 @@ class DiffusionPipeline(ConfigMixin): model_index_dict.pop("_module", None) model_index_dict.pop("_name_or_path", None) + if push_to_hub: + commit_message = kwargs.pop("commit_message", None) + private = kwargs.pop("private", False) + create_pr = kwargs.pop("create_pr", False) + token = kwargs.pop("token", None) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id + expected_modules, optional_kwargs = self._get_signature_keys(self) def is_saveable_module(name, value): @@ -641,6 +657,15 @@ class DiffusionPipeline(ConfigMixin): # finally save the config self.save_config(save_directory) + if push_to_hub: + self._upload_folder( + save_directory, + repo_id, + token=token, + commit_message=commit_message, + create_pr=create_pr, + ) + def to( self, torch_device: Optional[Union[str, torch.device]] = None, diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index 246cc71a13..a97a2d61e4 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -19,7 +19,7 @@ from typing import Any, Dict, Optional, Union import torch -from ..utils import BaseOutput +from ..utils import BaseOutput, PushToHubMixin SCHEDULER_CONFIG_NAME = "scheduler_config.json" @@ -60,7 +60,7 @@ class SchedulerOutput(BaseOutput): prev_sample: torch.FloatTensor -class SchedulerMixin: +class SchedulerMixin(PushToHubMixin): """ Base class for all schedulers. @@ -153,7 +153,13 @@ class SchedulerMixin: Args: save_directory (`str` or `os.PathLike`): - Directory to save a configuration JSON file to. Will be created if it doesn't exist. + Directory where the configuration JSON file will be saved (will be created if it does not exist). + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. """ self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) diff --git a/src/diffusers/schedulers/scheduling_utils_flax.py b/src/diffusers/schedulers/scheduling_utils_flax.py index 19ce5b8360..e2af382c82 100644 --- a/src/diffusers/schedulers/scheduling_utils_flax.py +++ b/src/diffusers/schedulers/scheduling_utils_flax.py @@ -21,7 +21,7 @@ from typing import Any, Dict, Optional, Tuple, Union import flax import jax.numpy as jnp -from ..utils import BaseOutput +from ..utils import BaseOutput, PushToHubMixin SCHEDULER_CONFIG_NAME = "scheduler_config.json" @@ -53,7 +53,7 @@ class FlaxSchedulerOutput(BaseOutput): prev_sample: jnp.ndarray -class FlaxSchedulerMixin: +class FlaxSchedulerMixin(PushToHubMixin): """ Mixin containing common functions for the schedulers. @@ -156,6 +156,12 @@ class FlaxSchedulerMixin: Args: save_directory (`str` or `os.PathLike`): Directory where the configuration JSON file will be saved (will be created if it does not exist). + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face Hub after saving it. 
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index 4a9045ddf3..9b710d214d 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -37,6 +37,7 @@ from .doc_utils import replace_example_docstring
 from .dynamic_modules_utils import get_class_from_dynamic_module
 from .hub_utils import (
     HF_HUB_OFFLINE,
+    PushToHubMixin,
     _add_variant,
     _get_model_file,
     extract_commit_hash,
diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py
index eeb3d15d12..6fc0f30ec3 100644
--- a/src/diffusers/utils/hub_utils.py
+++ b/src/diffusers/utils/hub_utils.py
@@ -17,13 +17,22 @@ import os
 import re
 import sys
+import tempfile
 import traceback
 import warnings
 from pathlib import Path
 from typing import Dict, Optional, Union
 from uuid import uuid4
 
-from huggingface_hub import HfFolder, ModelCard, ModelCardData, hf_hub_download, whoami
+from huggingface_hub import (
+    HfFolder,
+    ModelCard,
+    ModelCardData,
+    create_repo,
+    hf_hub_download,
+    upload_folder,
+    whoami,
+)
 from huggingface_hub.file_download import REGEX_COMMIT_HASH
 from huggingface_hub.utils import (
     EntryNotFoundError,
@@ -359,3 +368,96 @@ def _get_model_file(
                 f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
                 f"containing a file named {weights_name}"
             )
+
+
+class PushToHubMixin:
+    """
+    A Mixin containing the functionality to push a model, scheduler, or pipeline to the Hugging Face Hub.
+    """
+
+    def _upload_folder(
+        self,
+        working_dir: Union[str, os.PathLike],
+        repo_id: str,
+        token: Optional[str] = None,
+        commit_message: Optional[str] = None,
+        create_pr: bool = False,
+    ):
+        """
+        Uploads all files in `working_dir` to `repo_id`.
+        """
+        if commit_message is None:
+            if "Model" in self.__class__.__name__:
+                commit_message = "Upload model"
+            elif "Scheduler" in self.__class__.__name__:
+                commit_message = "Upload scheduler"
+            else:
+                commit_message = f"Upload {self.__class__.__name__}"
+
+        logger.info(f"Uploading the files of {working_dir} to {repo_id}.")
+        return upload_folder(
+            repo_id=repo_id, folder_path=working_dir, token=token, commit_message=commit_message, create_pr=create_pr
+        )
+
+    def push_to_hub(
+        self,
+        repo_id: str,
+        commit_message: Optional[str] = None,
+        private: Optional[bool] = None,
+        token: Optional[str] = None,
+        create_pr: bool = False,
+        safe_serialization: bool = True,
+        variant: Optional[str] = None,
+    ) -> str:
+        """
+        Upload model, scheduler, or pipeline files to the 🤗 Hugging Face Hub.
+
+        Parameters:
+            repo_id (`str`):
+                The name of the repository you want to push your model, scheduler, or pipeline files to. It should
+                contain your organization name when pushing to a given organization. `repo_id` can also be a path to a
+                local directory.
+            commit_message (`str`, *optional*):
+                Message to commit while pushing. Defaults to `"Upload model"` or `"Upload scheduler"` (or the class
+                name), matching the logic in `_upload_folder`.
+            private (`bool`, *optional*):
+                Whether or not the repository created should be private.
+            token (`str`, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If not set, the token generated when
+                running `huggingface-cli login` (stored in `~/.huggingface`) is used.
+            create_pr (`bool`, *optional*, defaults to `False`):
+                Whether or not to create a PR with the uploaded files or directly commit.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether or not to convert the model weights to the `safetensors` format for safer serialization.
+            variant (`str`, *optional*):
+                If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
+
+        Examples:
+
+        ```python
+        from diffusers import UNet2DConditionModel
+
+        unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="unet")
+
+        # Push the `unet` to your namespace with the name "my-finetuned-unet".
+        unet.push_to_hub("my-finetuned-unet")
+
+        # Push the `unet` to an organization with the name "my-finetuned-unet".
+        unet.push_to_hub("your-org/my-finetuned-unet")
+        ```
+        """
+        repo_id = create_repo(repo_id, private=private, token=token, exist_ok=True).repo_id
+
+        # Save all files.
+        save_kwargs = {"safe_serialization": safe_serialization}
+        if "Scheduler" not in self.__class__.__name__:
+            save_kwargs.update({"variant": variant})
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            self.save_pretrained(tmpdir, **save_kwargs)
+
+            return self._upload_folder(
+                tmpdir,
+                repo_id,
+                token=token,
+                commit_message=commit_message,
+                create_pr=create_pr,
+            )
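Beyond the docstring example, `create_pr=True` turns the upload into a pull request against an existing repository, which is the natural flow when you lack write access; a sketch with a hypothetical repo:

```python
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="unet")

# Files are staged in a TemporaryDirectory and uploaded from there, so nothing
# is left on disk; the return value is whatever upload_folder() returns.
unet.push_to_hub(
    "your-org/existing-unet-repo",  # hypothetical
    create_pr=True,
    commit_message="Propose updated weights",
)
```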
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index e27f9271ea..b9d1f924d7 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -18,18 +18,27 @@ import tempfile
 import traceback
 import unittest
 import unittest.mock as mock
+import uuid
 from typing import Dict, List, Tuple
 
 import numpy as np
 import requests_mock
 import torch
+from huggingface_hub import delete_repo
 from requests.exceptions import HTTPError
 
 from diffusers.models import UNet2DConditionModel
 from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0, XFormersAttnProcessor
 from diffusers.training_utils import EMAModel
 from diffusers.utils import logging, torch_device
-from diffusers.utils.testing_utils import CaptureLogger, require_torch_2, require_torch_gpu, run_test_in_subprocess
+from diffusers.utils.testing_utils import (
+    CaptureLogger,
+    require_torch_2,
+    require_torch_gpu,
+    run_test_in_subprocess,
+)
+
+from ..others.test_utils import TOKEN, USER, is_staging_test
 
 
 # Will be run via run_test_in_subprocess
@@ -563,3 +572,72 @@ class ModelTesterMixin:
             f" {self.model_class}.__init__ if there are deprecated arguments or remove the deprecated argument"
             " from `_deprecated_kwargs = []`"
         )
+
+
+@is_staging_test
+class ModelPushToHubTester(unittest.TestCase):
+    identifier = uuid.uuid4()
+    repo_id = f"test-model-{identifier}"
+    org_repo_id = f"valid_org/{repo_id}-org"
+
+    def test_push_to_hub(self):
+        model = UNet2DConditionModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            sample_size=32,
+            in_channels=4,
+            out_channels=4,
+            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+            cross_attention_dim=32,
+        )
+        model.push_to_hub(self.repo_id, token=TOKEN)
+
+        new_model = UNet2DConditionModel.from_pretrained(f"{USER}/{self.repo_id}")
+        for p1, p2 in zip(model.parameters(), new_model.parameters()):
+            self.assertTrue(torch.equal(p1, p2))
+
+        # Reset repo
+        delete_repo(token=TOKEN, repo_id=self.repo_id)
+
+        # Push to hub via save_pretrained
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir, repo_id=self.repo_id, push_to_hub=True, token=TOKEN)
+
+        new_model = UNet2DConditionModel.from_pretrained(f"{USER}/{self.repo_id}")
+        for p1, p2 in zip(model.parameters(), new_model.parameters()):
+            self.assertTrue(torch.equal(p1, p2))
+
+        # Reset repo
+        delete_repo(self.repo_id, token=TOKEN)
+
+    def test_push_to_hub_in_organization(self):
+        model = UNet2DConditionModel(
+            block_out_channels=(32, 64),
+            layers_per_block=2,
+            sample_size=32,
+            in_channels=4,
+            out_channels=4,
+            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+            cross_attention_dim=32,
+        )
+        model.push_to_hub(self.org_repo_id, token=TOKEN)
+
+        new_model = UNet2DConditionModel.from_pretrained(self.org_repo_id)
+        for p1, p2 in zip(model.parameters(), new_model.parameters()):
+            self.assertTrue(torch.equal(p1, p2))
+
+        # Reset repo
+        delete_repo(token=TOKEN, repo_id=self.org_repo_id)
+
+        # Push to hub via save_pretrained
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir, push_to_hub=True, token=TOKEN, repo_id=self.org_repo_id)
+
+        new_model = UNet2DConditionModel.from_pretrained(self.org_repo_id)
+        for p1, p2 in zip(model.parameters(), new_model.parameters()):
+            self.assertTrue(torch.equal(p1, p2))
+
+        # Reset repo
+        delete_repo(self.org_repo_id, token=TOKEN)
diff --git a/tests/others/test_utils.py b/tests/others/test_utils.py
index 6e7cc095f8..9dc73c0a74 100755
--- a/tests/others/test_utils.py
+++ b/tests/others/test_utils.py
@@ -13,12 +13,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 import unittest
+from distutils.util import strtobool
+
+import pytest
 
 from diffusers import __version__
 from diffusers.utils import deprecate
 
 
+# Used to test the hub
+USER = "__DUMMY_TRANSFORMERS_USER__"
+ENDPOINT_STAGING = "https://hub-ci.huggingface.co"
+
+# Not critical, only usable on the sandboxed CI instance.
+TOKEN = "hf_94wBhPGp6KrrTH3KDchhKpRxZwd6dmHWLL"
+
+
 class DeprecateTester(unittest.TestCase):
     higher_version = ".".join([str(int(__version__.split(".")[0]) + 1)] + __version__.split(".")[1:])
     lower_version = "0.0.1"
@@ -168,3 +180,34 @@ class DeprecateTester(unittest.TestCase):
             deprecate(("deprecated_arg", self.higher_version, "This message is better!!!"), standard_warn=False)
         assert str(warning.warning) == "This message is better!!!"
         assert "diffusers/tests/others/test_utils.py" in warning.filename
+
+
+def parse_flag_from_env(key, default=False):
+    try:
+        value = os.environ[key]
+    except KeyError:
+        # KEY isn't set, default to `default`.
+        _value = default
+    else:
+        # KEY is set, convert it to True or False.
+        try:
+            _value = strtobool(value)
+        except ValueError:
+            # More values are supported, but let's keep the message simple.
+            raise ValueError(f"If set, {key} must be yes or no.")
+    return _value
+
+
+_run_staging = parse_flag_from_env("HUGGINGFACE_CO_STAGING", default=False)
+
+
+def is_staging_test(test_case):
+    """
+    Decorator marking a test as a staging test.
+
+    These tests will run using the staging environment of huggingface.co instead of the real model hub.
+    """
+    if not _run_staging:
+        return unittest.skip("test is staging test")(test_case)
+    else:
+        return pytest.mark.is_staging_test()(test_case)
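How the decorator is consumed from a sibling test module (mirroring the relative imports used in the test files in this patch); note that `parse_flag_from_env` runs at import time, so `HUGGINGFACE_CO_STAGING` must be set before this module is imported:

```python
import unittest

from ..others.test_utils import TOKEN, is_staging_test  # inside the tests package


@is_staging_test
class ExampleStagingTest(unittest.TestCase):
    def test_token_prefix(self):
        # Skipped entirely unless HUGGINGFACE_CO_STAGING parses truthy via
        # strtobool ("y", "yes", "t", "true", "on", "1").
        self.assertTrue(TOKEN.startswith("hf_"))
```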
+ """ + if not _run_staging: + return unittest.skip("test is staging test")(test_case) + else: + return pytest.mark.is_staging_test()(test_case) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index d9a2e04485..be7ae1a315 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -2,23 +2,30 @@ import contextlib import gc import inspect import io +import json +import os import re import tempfile import unittest +import uuid from typing import Callable, Union import numpy as np import PIL import torch +from huggingface_hub import delete_repo +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer import diffusers -from diffusers import DiffusionPipeline +from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel from diffusers.image_processor import VaeImageProcessor from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.utils import logging from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version, is_xformers_available from diffusers.utils.testing_utils import CaptureLogger, require_torch, torch_device +from ..others.test_utils import TOKEN, USER, is_staging_test + def to_np(tensor): if isinstance(tensor, torch.Tensor): @@ -795,6 +802,126 @@ class PipelineTesterMixin: assert out_cfg.shape == out_no_cfg.shape +@is_staging_test +class PipelinePushToHubTester(unittest.TestCase): + identifier = uuid.uuid4() + repo_id = f"test-pipeline-{identifier}" + org_repo_id = f"valid_org/{repo_id}-org" + + def get_pipeline_components(self): + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + + with tempfile.TemporaryDirectory() as tmpdir: + dummy_vocab = {"<|startoftext|>": 0, "<|endoftext|>": 1, "!": 2} + vocab_path = os.path.join(tmpdir, "vocab.json") + with open(vocab_path, "w") as f: + json.dump(dummy_vocab, f) + + merges = "Ġ t\nĠt h" + merges_path = os.path.join(tmpdir, "merges.txt") + with open(merges_path, "w") as f: + f.writelines(merges) + tokenizer = CLIPTokenizer(vocab_file=vocab_path, merges_file=merges_path) + + components = { + "unet": unet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "safety_checker": None, + "feature_extractor": None, + } + return components + + def test_push_to_hub(self): + components = self.get_pipeline_components() + pipeline = StableDiffusionPipeline(**components) + pipeline.push_to_hub(self.repo_id, token=TOKEN) + + new_model = UNet2DConditionModel.from_pretrained(f"{USER}/{self.repo_id}", 
subfolder="unet") + unet = components["unet"] + for p1, p2 in zip(unet.parameters(), new_model.parameters()): + self.assertTrue(torch.equal(p1, p2)) + + # Reset repo + delete_repo(token=TOKEN, repo_id=self.repo_id) + + # Push to hub via save_pretrained + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir, repo_id=self.repo_id, push_to_hub=True, token=TOKEN) + + new_model = UNet2DConditionModel.from_pretrained(f"{USER}/{self.repo_id}", subfolder="unet") + for p1, p2 in zip(unet.parameters(), new_model.parameters()): + self.assertTrue(torch.equal(p1, p2)) + + # Reset repo + delete_repo(self.repo_id, token=TOKEN) + + def test_push_to_hub_in_organization(self): + components = self.get_pipeline_components() + pipeline = StableDiffusionPipeline(**components) + pipeline.push_to_hub(self.org_repo_id, token=TOKEN) + + new_model = UNet2DConditionModel.from_pretrained(self.org_repo_id, subfolder="unet") + unet = components["unet"] + for p1, p2 in zip(unet.parameters(), new_model.parameters()): + self.assertTrue(torch.equal(p1, p2)) + + # Reset repo + delete_repo(token=TOKEN, repo_id=self.org_repo_id) + + # Push to hub via save_pretrained + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir, push_to_hub=True, token=TOKEN, repo_id=self.org_repo_id) + + new_model = UNet2DConditionModel.from_pretrained(self.org_repo_id, subfolder="unet") + for p1, p2 in zip(unet.parameters(), new_model.parameters()): + self.assertTrue(torch.equal(p1, p2)) + + # Reset repo + delete_repo(self.org_repo_id, token=TOKEN) + + # Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used. # This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a # reference image. 
diff --git a/tests/schedulers/test_schedulers.py b/tests/schedulers/test_schedulers.py
index d9423d6219..3c34bfe039 100755
--- a/tests/schedulers/test_schedulers.py
+++ b/tests/schedulers/test_schedulers.py
@@ -17,10 +17,12 @@ import json
 import os
 import tempfile
 import unittest
+import uuid
 from typing import Dict, List, Tuple
 
 import numpy as np
 import torch
+from huggingface_hub import delete_repo
 
 import diffusers
 from diffusers import (
@@ -41,6 +43,8 @@ from diffusers.schedulers.scheduling_utils import SchedulerMixin
 from diffusers.utils import torch_device
 from diffusers.utils.testing_utils import CaptureLogger
 
+from ..others.test_utils import TOKEN, USER, is_staging_test
+
 
 torch.backends.cuda.matmul.allow_tf32 = False
@@ -720,3 +724,64 @@ class SchedulerCommonTest(unittest.TestCase):
             scheduler.does_not_exist
 
         assert str(error.exception) == f"'{type(scheduler).__name__}' object has no attribute 'does_not_exist'"
+
+
+@is_staging_test
+class SchedulerPushToHubTester(unittest.TestCase):
+    identifier = uuid.uuid4()
+    repo_id = f"test-scheduler-{identifier}"
+    org_repo_id = f"valid_org/{repo_id}-org"
+
+    def test_push_to_hub(self):
+        scheduler = DDIMScheduler(
+            beta_start=0.00085,
+            beta_end=0.012,
+            beta_schedule="scaled_linear",
+            clip_sample=False,
+            set_alpha_to_one=False,
+        )
+        scheduler.push_to_hub(self.repo_id, token=TOKEN)
+        scheduler_loaded = DDIMScheduler.from_pretrained(f"{USER}/{self.repo_id}")
+
+        assert type(scheduler) == type(scheduler_loaded)
+
+        # Reset repo
+        delete_repo(token=TOKEN, repo_id=self.repo_id)
+
+        # Push to hub via save_config
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            scheduler.save_config(tmp_dir, repo_id=self.repo_id, push_to_hub=True, token=TOKEN)
+
+        scheduler_loaded = DDIMScheduler.from_pretrained(f"{USER}/{self.repo_id}")
+
+        assert type(scheduler) == type(scheduler_loaded)
+
+        # Reset repo
+        delete_repo(token=TOKEN, repo_id=self.repo_id)
+
+    def test_push_to_hub_in_organization(self):
+        scheduler = DDIMScheduler(
+            beta_start=0.00085,
+            beta_end=0.012,
+            beta_schedule="scaled_linear",
+            clip_sample=False,
+            set_alpha_to_one=False,
+        )
+        scheduler.push_to_hub(self.org_repo_id, token=TOKEN)
+        scheduler_loaded = DDIMScheduler.from_pretrained(self.org_repo_id)
+
+        assert type(scheduler) == type(scheduler_loaded)
+
+        # Reset repo
+        delete_repo(token=TOKEN, repo_id=self.org_repo_id)
+
+        # Push to hub via save_config
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            scheduler.save_config(tmp_dir, repo_id=self.org_repo_id, push_to_hub=True, token=TOKEN)
+
+        scheduler_loaded = DDIMScheduler.from_pretrained(self.org_repo_id)
+
+        assert type(scheduler) == type(scheduler_loaded)
+
+        # Reset repo
+        delete_repo(token=TOKEN, repo_id=self.org_repo_id)
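To run the new staging suites locally the way the workflow at the top does (a sketch; the flag must be set before pytest imports the test modules):

```python
# Equivalent to: HUGGINGFACE_CO_STAGING=true python -m pytest -m "is_staging_test" tests
import os

import pytest

os.environ["HUGGINGFACE_CO_STAGING"] = "true"  # read when the test modules are imported
raise SystemExit(pytest.main(["-m", "is_staging_test", "tests"]))
```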