diff --git a/setup.py b/setup.py
index 9148acce26..4ebec86927 100644
--- a/setup.py
+++ b/setup.py
@@ -97,6 +97,7 @@ _deps = [
     "pytest",
     "pytest-timeout",
     "pytest-xdist",
+    "safetensors",
     "sentencepiece>=0.1.91,!=0.1.92",
     "scipy",
     "regex!=2019.12.17",
@@ -184,10 +185,11 @@ extras["test"] = deps_list(
     "pytest",
     "pytest-timeout",
     "pytest-xdist",
+    "safetensors",
     "sentencepiece",
     "scipy",
     "torchvision",
-    "transformers"
+    "transformers",
 )
 
 extras["torch"] = deps_list("torch", "accelerate")
diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py
index d187b79145..2fd6bfa1fa 100644
--- a/src/diffusers/dependency_versions_table.py
+++ b/src/diffusers/dependency_versions_table.py
@@ -21,6 +21,7 @@ deps = {
     "pytest": "pytest",
     "pytest-timeout": "pytest-timeout",
     "pytest-xdist": "pytest-xdist",
+    "safetensors": "safetensors",
     "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
     "scipy": "scipy",
     "regex": "regex!=2019.12.17",
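Reviewer note: for anyone unfamiliar with the new dependency, `safetensors` serializes tensors without pickle, so loading is both safe and fast. A minimal round-trip, as a sketch outside the diff (assumes `pip install safetensors`; the file name is arbitrary):

```python
# Round-trip a state dict through safetensors (illustrative only).
import torch
from safetensors.torch import load_file, save_file

tensors = {"weight": torch.ones(2, 2), "bias": torch.zeros(2)}
save_file(tensors, "model.safetensors")  # no pickle involved
restored = load_file("model.safetensors", device="cpu")
assert torch.equal(restored["weight"], tensors["weight"])
```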
diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py
index 8cb0acf52f..5f79e7fe01 100644
--- a/src/diffusers/modeling_utils.py
+++ b/src/diffusers/modeling_utils.py
@@ -30,8 +30,10 @@ from .utils import (
     CONFIG_NAME,
     DIFFUSERS_CACHE,
     HUGGINGFACE_CO_RESOLVE_ENDPOINT,
+    SAFETENSORS_WEIGHTS_NAME,
     WEIGHTS_NAME,
     is_accelerate_available,
+    is_safetensors_available,
     is_torch_version,
     logging,
 )
@@ -51,6 +53,9 @@ if is_accelerate_available():
     from accelerate.utils import set_module_tensor_to_device
     from accelerate.utils.versions import is_torch_version
 
+if is_safetensors_available():
+    import safetensors
+
 
 def get_parameter_device(parameter: torch.nn.Module):
     try:
@@ -84,10 +89,13 @@ def get_parameter_dtype(parameter: torch.nn.Module):
 
 def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
     """
-    Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
+    Reads a checkpoint file, returning properly formatted errors if they arise.
     """
     try:
-        return torch.load(checkpoint_file, map_location="cpu")
+        if os.path.basename(checkpoint_file) == WEIGHTS_NAME:
+            return torch.load(checkpoint_file, map_location="cpu")
+        else:
+            return safetensors.torch.load_file(checkpoint_file, device="cpu")
     except Exception as e:
         try:
             with open(checkpoint_file) as f:
@@ -104,7 +112,7 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
             )
         except (UnicodeDecodeError, ValueError):
             raise OSError(
-                f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' "
+                f"Unable to load weights from checkpoint file for '{checkpoint_file}' "
                 f"at '{checkpoint_file}'. "
                 "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
             )
@@ -375,75 +383,39 @@ class ModelMixin(torch.nn.Module):
         # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
 
         # Load model
-        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
-        if os.path.isdir(pretrained_model_name_or_path):
-            if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
-                # Load from a PyTorch checkpoint
-                model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
-            elif subfolder is not None and os.path.isfile(
-                os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
-            ):
-                model_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
-            else:
-                raise EnvironmentError(
-                    f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
-                )
-        else:
+
+        model_file = None
+        if is_safetensors_available():
             try:
-                # Load from URL or cache if already cached
-                model_file = hf_hub_download(
+                model_file = _get_model_file(
                     pretrained_model_name_or_path,
-                    filename=WEIGHTS_NAME,
+                    weights_name=SAFETENSORS_WEIGHTS_NAME,
                     cache_dir=cache_dir,
                     force_download=force_download,
-                    proxies=proxies,
                     resume_download=resume_download,
+                    proxies=proxies,
                     local_files_only=local_files_only,
                     use_auth_token=use_auth_token,
-                    user_agent=user_agent,
-                    subfolder=subfolder,
                     revision=revision,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
                 )
-
-            except RepositoryNotFoundError:
-                raise EnvironmentError(
-                    f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
-                    "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
-                    "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
-                    "login`."
-                )
-            except RevisionNotFoundError:
-                raise EnvironmentError(
-                    f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
-                    "this model name. Check the model page at "
-                    f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
-                )
-            except EntryNotFoundError:
-                raise EnvironmentError(
-                    f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}."
-                )
-            except HTTPError as err:
-                raise EnvironmentError(
-                    "There was a specific connection error when trying to load"
-                    f" {pretrained_model_name_or_path}:\n{err}"
-                )
-            except ValueError:
-                raise EnvironmentError(
-                    f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
-                    f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
-                    f" directory containing a file named {WEIGHTS_NAME} or"
-                    " \nCheckout your internet connection or see how to run the library in"
-                    " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
-                )
-            except EnvironmentError:
-                raise EnvironmentError(
-                    f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
-                    "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
-                    f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
-                    f"containing a file named {WEIGHTS_NAME}"
-                )
-
-        # restore default dtype
+            except Exception:
+                pass  # fall back to the PyTorch weights below
+        if model_file is None:
+            model_file = _get_model_file(
+                pretrained_model_name_or_path,
+                weights_name=WEIGHTS_NAME,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                revision=revision,
+                subfolder=subfolder,
+                user_agent=user_agent,
+            )
 
         if low_cpu_mem_usage:
             # Instantiate model with empty weights
" - f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " - f"containing a file named {WEIGHTS_NAME}" - ) - - # restore default dtype + except: + pass + if model_file is None: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=WEIGHTS_NAME, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) if low_cpu_mem_usage: # Instantiate model with empty weights @@ -691,3 +663,88 @@ def unwrap_model(model: torch.nn.Module) -> torch.nn.Module: return unwrap_model(model.module) else: return model + + +def _get_model_file( + pretrained_model_name_or_path, + *, + weights_name, + subfolder, + cache_dir, + force_download, + proxies, + resume_download, + local_files_only, + use_auth_token, + user_agent, + revision, +): + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + if os.path.isdir(pretrained_model_name_or_path): + if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)): + # Load from a PyTorch checkpoint + model_file = os.path.join(pretrained_model_name_or_path, weights_name) + return model_file + elif subfolder is not None and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, weights_name) + ): + model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name) + return model_file + else: + raise EnvironmentError( + f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}." + ) + else: + try: + # Load from URL or cache if already cached + model_file = hf_hub_download( + pretrained_model_name_or_path, + filename=weights_name, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + user_agent=user_agent, + subfolder=subfolder, + revision=revision, + ) + return model_file + + except RepositoryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier " + "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a " + "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli " + "login`." + ) + except RevisionNotFoundError: + raise EnvironmentError( + f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for " + "this model name. Check the model page at " + f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." + ) + except EntryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}." + ) + except HTTPError as err: + raise EnvironmentError( + f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}" + ) + except ValueError: + raise EnvironmentError( + f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it" + f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a" + f" directory containing a file named {weights_name} or" + " \nCheckout your internet connection or see how to run the library in" + " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'." 
+ ) + except EnvironmentError: + raise EnvironmentError( + f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from " + "'https://huggingface.co/models', make sure you don't have a local directory with the same name. " + f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " + f"containing a file named {weights_name}" + ) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 35ebd536c5..5dab802ba8 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -26,7 +26,7 @@ import torch import diffusers import PIL -from huggingface_hub import snapshot_download +from huggingface_hub import model_info, snapshot_download from packaging import version from PIL import Image from tqdm.auto import tqdm @@ -44,6 +44,7 @@ from .utils import ( BaseOutput, deprecate, is_accelerate_available, + is_safetensors_available, is_torch_version, is_transformers_available, logging, @@ -117,6 +118,23 @@ class AudioPipelineOutput(BaseOutput): audios: np.ndarray +def is_safetensors_compatible(info) -> bool: + filenames = set(sibling.rfilename for sibling in info.siblings) + pt_filenames = set(filename for filename in filenames if filename.endswith(".bin")) + is_safetensors_compatible = any(file.endswith(".safetensors") for file in filenames) + for pt_filename in pt_filenames: + prefix, raw = os.path.split(pt_filename) + if raw == "pytorch_model.bin": + # transformers specific + sf_filename = os.path.join(prefix, "model.safetensors") + else: + sf_filename = pt_filename[: -len(".bin")] + ".safetensors" + if sf_filename not in filenames: + logger.warning("{sf_filename} not found") + is_safetensors_compatible = False + return is_safetensors_compatible + + class DiffusionPipeline(ConfigMixin): r""" Base class for all models. 
diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py
index 35ebd536c5..5dab802ba8 100644
--- a/src/diffusers/pipeline_utils.py
+++ b/src/diffusers/pipeline_utils.py
@@ -26,7 +26,7 @@ import torch
 
 import diffusers
 import PIL
-from huggingface_hub import snapshot_download
+from huggingface_hub import model_info, snapshot_download
 from packaging import version
 from PIL import Image
 from tqdm.auto import tqdm
@@ -44,6 +44,7 @@ from .utils import (
     BaseOutput,
     deprecate,
     is_accelerate_available,
+    is_safetensors_available,
     is_torch_version,
     is_transformers_available,
     logging,
@@ -117,6 +118,23 @@ class AudioPipelineOutput(BaseOutput):
     audios: np.ndarray
 
 
+def is_safetensors_compatible(info) -> bool:
+    filenames = set(sibling.rfilename for sibling in info.siblings)
+    pt_filenames = set(filename for filename in filenames if filename.endswith(".bin"))
+    is_safetensors_compatible = any(file.endswith(".safetensors") for file in filenames)
+    for pt_filename in pt_filenames:
+        prefix, raw = os.path.split(pt_filename)
+        if raw == "pytorch_model.bin":
+            # transformers specific
+            sf_filename = os.path.join(prefix, "model.safetensors")
+        else:
+            sf_filename = pt_filename[: -len(".bin")] + ".safetensors"
+        if sf_filename not in filenames:
+            logger.warning(f"{sf_filename} not found")
+            is_safetensors_compatible = False
+    return is_safetensors_compatible
+
+
 class DiffusionPipeline(ConfigMixin):
     r"""
     Base class for all models.
@@ -459,7 +477,7 @@ class DiffusionPipeline(ConfigMixin):
             allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name]
 
             # make sure we don't download flax weights
-            ignore_patterns = "*.msgpack"
+            ignore_patterns = ["*.msgpack"]
 
             if custom_pipeline is not None:
                 allow_patterns += [CUSTOM_PIPELINE_FILE_NAME]
@@ -473,6 +491,15 @@ class DiffusionPipeline(ConfigMixin):
                 user_agent["custom_pipeline"] = custom_pipeline
             user_agent = http_user_agent(user_agent)
 
+            if is_safetensors_available():
+                info = model_info(
+                    pretrained_model_name_or_path,
+                    use_auth_token=use_auth_token,
+                    revision=revision,
+                )
+                if is_safetensors_compatible(info):
+                    ignore_patterns.append("*.bin")
+
             # download all allow_patterns
             cached_folder = snapshot_download(
                 pretrained_model_name_or_path,
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index e86f3b801a..3dba3a2bc2 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -28,6 +28,7 @@ from .import_utils import (
     is_inflect_available,
     is_modelcards_available,
     is_onnx_available,
+    is_safetensors_available,
     is_scipy_available,
     is_tf_available,
     is_torch_available,
@@ -69,6 +70,7 @@ CONFIG_NAME = "config.json"
 WEIGHTS_NAME = "diffusion_pytorch_model.bin"
 FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
 ONNX_WEIGHTS_NAME = "model.onnx"
+SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
 ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
 HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
 DIFFUSERS_CACHE = default_cache_path
diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py
index c0294b4a3d..86d5879080 100644
--- a/src/diffusers/utils/import_utils.py
+++ b/src/diffusers/utils/import_utils.py
@@ -42,6 +42,7 @@ ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
 USE_TF = os.environ.get("USE_TF", "AUTO").upper()
 USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
 USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
+USE_SAFETENSORS = os.environ.get("USE_SAFETENSORS", "AUTO").upper()
 
 STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}
 
@@ -55,7 +56,7 @@ if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VA
     except importlib_metadata.PackageNotFoundError:
         _torch_available = False
 else:
-    logger.info("Disabling PyTorch because USE_TF is set")
+    logger.info("Disabling PyTorch because USE_TORCH is set")
     _torch_available = False
 
 
@@ -109,6 +110,17 @@ if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
 else:
     _flax_available = False
 
+if USE_SAFETENSORS in ENV_VARS_TRUE_AND_AUTO_VALUES:
+    _safetensors_available = importlib.util.find_spec("safetensors") is not None
+    if _safetensors_available:
+        try:
+            _safetensors_version = importlib_metadata.version("safetensors")
+            logger.info(f"Safetensors version {_safetensors_version} available.")
+        except importlib_metadata.PackageNotFoundError:
+            _safetensors_available = False
+else:
+    logger.info("Disabling Safetensors because USE_SAFETENSORS is set")
+    _safetensors_available = False
 
 _transformers_available = importlib.util.find_spec("transformers") is not None
 try:
@@ -190,6 +202,10 @@ def is_torch_available():
     return _torch_available
 
 
+def is_safetensors_available():
+    return _safetensors_available
+
+
 def is_tf_available():
     return _tf_available
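Reviewer note: like `USE_TF`/`USE_TORCH`, the new `USE_SAFETENSORS` flag is read once at import time, so opting out has to happen before `diffusers` is imported. A sketch (hypothetical script; assumes this PR is installed and the usual HF truthy values `1`/`ON`/`YES`/`TRUE` plus `AUTO`):

```python
import os

# Must run before any diffusers import; the flag is evaluated at module load.
os.environ["USE_SAFETENSORS"] = "0"

from diffusers.utils import is_safetensors_available

assert not is_safetensors_available()  # safetensors loading is now disabled
```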
diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index 0aad9de8be..033f363ff4 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -92,6 +92,24 @@ class DownloadTests(unittest.TestCase):
         # None of the downloaded files should be a flax file even if we have some here:
         # https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_flax_model.msgpack
         assert not any(f.endswith(".msgpack") for f in files)
+        # We need to never convert this tiny model to safetensors for this test to pass
+        assert not any(f.endswith(".safetensors") for f in files)
+
+    def test_download_safetensors(self):
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            # pipeline has safetensors weights
+            _ = DiffusionPipeline.from_pretrained(
+                "hf-internal-testing/tiny-stable-diffusion-pipe-safetensors",
+                safety_checker=None,
+                cache_dir=tmpdirname,
+            )
+
+            all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))]
+            files = [item for sublist in all_root_files for item in sublist]
+
+            # None of the downloaded files should be a pytorch file even if we have some here:
+            # https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe-safetensors
+            assert not any(f.endswith(".bin") for f in files)
 
     def test_download_no_safety_checker(self):
         prompt = "hello"
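For completeness, the behavior the new test pins down, written as a standalone sketch (assumes `safetensors` is installed and Hub access; the repo id is the test fixture used above, whose `.bin` files all have `.safetensors` twins):

```python
import os
import tempfile

from diffusers import DiffusionPipeline

with tempfile.TemporaryDirectory() as tmpdir:
    # With safetensors available and is_safetensors_compatible(info) True,
    # snapshot_download adds "*.bin" to ignore_patterns.
    DiffusionPipeline.from_pretrained(
        "hf-internal-testing/tiny-stable-diffusion-pipe-safetensors",
        safety_checker=None,
        cache_dir=tmpdir,
    )
    downloaded = [f for _, _, fs in os.walk(tmpdir) for f in fs]
    assert any(f.endswith(".safetensors") for f in downloaded)
    assert not any(f.endswith(".bin") for f in downloaded)
```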