mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Proposal] Support loading from safetensors if file is present. (#1357)
* [Proposal] Support loading from safetensors if file is present. * Style. * Fix. * Adding some test to check loading logic. + modify download logic to not download pytorch file if not necessary. * Fixing the logic. * Adressing comments. * factor out into a function. * Remove dead function. * Typo. * Extra fetch only if safetensors is there. * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
4
setup.py
4
setup.py
@@ -97,6 +97,7 @@ _deps = [
|
||||
"pytest",
|
||||
"pytest-timeout",
|
||||
"pytest-xdist",
|
||||
"safetensors",
|
||||
"sentencepiece>=0.1.91,!=0.1.92",
|
||||
"scipy",
|
||||
"regex!=2019.12.17",
|
||||
@@ -184,10 +185,11 @@ extras["test"] = deps_list(
|
||||
"pytest",
|
||||
"pytest-timeout",
|
||||
"pytest-xdist",
|
||||
"safetensors",
|
||||
"sentencepiece",
|
||||
"scipy",
|
||||
"torchvision",
|
||||
"transformers"
|
||||
"transformers",
|
||||
)
|
||||
extras["torch"] = deps_list("torch", "accelerate")
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ deps = {
|
||||
"pytest": "pytest",
|
||||
"pytest-timeout": "pytest-timeout",
|
||||
"pytest-xdist": "pytest-xdist",
|
||||
"safetensors": "safetensors",
|
||||
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
|
||||
"scipy": "scipy",
|
||||
"regex": "regex!=2019.12.17",
|
||||
|
||||
@@ -30,8 +30,10 @@ from .utils import (
|
||||
CONFIG_NAME,
|
||||
DIFFUSERS_CACHE,
|
||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
||||
SAFETENSORS_WEIGHTS_NAME,
|
||||
WEIGHTS_NAME,
|
||||
is_accelerate_available,
|
||||
is_safetensors_available,
|
||||
is_torch_version,
|
||||
logging,
|
||||
)
|
||||
@@ -51,6 +53,9 @@ if is_accelerate_available():
|
||||
from accelerate.utils import set_module_tensor_to_device
|
||||
from accelerate.utils.versions import is_torch_version
|
||||
|
||||
if is_safetensors_available():
|
||||
import safetensors
|
||||
|
||||
|
||||
def get_parameter_device(parameter: torch.nn.Module):
|
||||
try:
|
||||
@@ -84,10 +89,13 @@ def get_parameter_dtype(parameter: torch.nn.Module):
|
||||
|
||||
def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
|
||||
"""
|
||||
Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
|
||||
Reads a checkpoint file, returning properly formatted errors if they arise.
|
||||
"""
|
||||
try:
|
||||
return torch.load(checkpoint_file, map_location="cpu")
|
||||
if os.path.basename(checkpoint_file) == WEIGHTS_NAME:
|
||||
return torch.load(checkpoint_file, map_location="cpu")
|
||||
else:
|
||||
return safetensors.torch.load_file(checkpoint_file, device="cpu")
|
||||
except Exception as e:
|
||||
try:
|
||||
with open(checkpoint_file) as f:
|
||||
@@ -104,7 +112,7 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
|
||||
) from e
|
||||
except (UnicodeDecodeError, ValueError):
|
||||
raise OSError(
|
||||
f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' "
|
||||
f"Unable to load weights from checkpoint file for '{checkpoint_file}' "
|
||||
f"at '{checkpoint_file}'. "
|
||||
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
|
||||
)
|
||||
@@ -375,75 +383,39 @@ class ModelMixin(torch.nn.Module):
|
||||
|
||||
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
|
||||
# Load model
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
|
||||
# Load from a PyTorch checkpoint
|
||||
model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
||||
elif subfolder is not None and os.path.isfile(
|
||||
os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
|
||||
):
|
||||
model_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
|
||||
else:
|
||||
raise EnvironmentError(
|
||||
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
|
||||
)
|
||||
else:
|
||||
|
||||
model_file = None
|
||||
if is_safetensors_available():
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
model_file = hf_hub_download(
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path,
|
||||
filename=WEIGHTS_NAME,
|
||||
weights_name=SAFETENSORS_WEIGHTS_NAME,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
subfolder=subfolder,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
|
||||
except RepositoryNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
|
||||
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
|
||||
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
|
||||
"login`."
|
||||
)
|
||||
except RevisionNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
|
||||
"this model name. Check the model page at "
|
||||
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
|
||||
)
|
||||
except EntryNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}."
|
||||
)
|
||||
except HTTPError as err:
|
||||
raise EnvironmentError(
|
||||
"There was a specific connection error when trying to load"
|
||||
f" {pretrained_model_name_or_path}:\n{err}"
|
||||
)
|
||||
except ValueError:
|
||||
raise EnvironmentError(
|
||||
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
|
||||
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
|
||||
f" directory containing a file named {WEIGHTS_NAME} or"
|
||||
" \nCheckout your internet connection or see how to run the library in"
|
||||
" offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
|
||||
)
|
||||
except EnvironmentError:
|
||||
raise EnvironmentError(
|
||||
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
||||
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
||||
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
||||
f"containing a file named {WEIGHTS_NAME}"
|
||||
)
|
||||
|
||||
# restore default dtype
|
||||
except:
|
||||
pass
|
||||
if model_file is None:
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path,
|
||||
weights_name=WEIGHTS_NAME,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage:
|
||||
# Instantiate model with empty weights
|
||||
@@ -691,3 +663,88 @@ def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
|
||||
return unwrap_model(model.module)
|
||||
else:
|
||||
return model
|
||||
|
||||
|
||||
def _get_model_file(
|
||||
pretrained_model_name_or_path,
|
||||
*,
|
||||
weights_name,
|
||||
subfolder,
|
||||
cache_dir,
|
||||
force_download,
|
||||
proxies,
|
||||
resume_download,
|
||||
local_files_only,
|
||||
use_auth_token,
|
||||
user_agent,
|
||||
revision,
|
||||
):
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
|
||||
# Load from a PyTorch checkpoint
|
||||
model_file = os.path.join(pretrained_model_name_or_path, weights_name)
|
||||
return model_file
|
||||
elif subfolder is not None and os.path.isfile(
|
||||
os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
|
||||
):
|
||||
model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
|
||||
return model_file
|
||||
else:
|
||||
raise EnvironmentError(
|
||||
f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
|
||||
)
|
||||
else:
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
model_file = hf_hub_download(
|
||||
pretrained_model_name_or_path,
|
||||
filename=weights_name,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
subfolder=subfolder,
|
||||
revision=revision,
|
||||
)
|
||||
return model_file
|
||||
|
||||
except RepositoryNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
|
||||
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
|
||||
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
|
||||
"login`."
|
||||
)
|
||||
except RevisionNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
|
||||
"this model name. Check the model page at "
|
||||
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
|
||||
)
|
||||
except EntryNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
|
||||
)
|
||||
except HTTPError as err:
|
||||
raise EnvironmentError(
|
||||
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
|
||||
)
|
||||
except ValueError:
|
||||
raise EnvironmentError(
|
||||
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
|
||||
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
|
||||
f" directory containing a file named {weights_name} or"
|
||||
" \nCheckout your internet connection or see how to run the library in"
|
||||
" offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
|
||||
)
|
||||
except EnvironmentError:
|
||||
raise EnvironmentError(
|
||||
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
||||
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
||||
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
||||
f"containing a file named {weights_name}"
|
||||
)
|
||||
|
||||
@@ -26,7 +26,7 @@ import torch
|
||||
|
||||
import diffusers
|
||||
import PIL
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub import model_info, snapshot_download
|
||||
from packaging import version
|
||||
from PIL import Image
|
||||
from tqdm.auto import tqdm
|
||||
@@ -44,6 +44,7 @@ from .utils import (
|
||||
BaseOutput,
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
is_safetensors_available,
|
||||
is_torch_version,
|
||||
is_transformers_available,
|
||||
logging,
|
||||
@@ -117,6 +118,23 @@ class AudioPipelineOutput(BaseOutput):
|
||||
audios: np.ndarray
|
||||
|
||||
|
||||
def is_safetensors_compatible(info) -> bool:
|
||||
filenames = set(sibling.rfilename for sibling in info.siblings)
|
||||
pt_filenames = set(filename for filename in filenames if filename.endswith(".bin"))
|
||||
is_safetensors_compatible = any(file.endswith(".safetensors") for file in filenames)
|
||||
for pt_filename in pt_filenames:
|
||||
prefix, raw = os.path.split(pt_filename)
|
||||
if raw == "pytorch_model.bin":
|
||||
# transformers specific
|
||||
sf_filename = os.path.join(prefix, "model.safetensors")
|
||||
else:
|
||||
sf_filename = pt_filename[: -len(".bin")] + ".safetensors"
|
||||
if sf_filename not in filenames:
|
||||
logger.warning("{sf_filename} not found")
|
||||
is_safetensors_compatible = False
|
||||
return is_safetensors_compatible
|
||||
|
||||
|
||||
class DiffusionPipeline(ConfigMixin):
|
||||
r"""
|
||||
Base class for all models.
|
||||
@@ -459,7 +477,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||
allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name]
|
||||
|
||||
# make sure we don't download flax weights
|
||||
ignore_patterns = "*.msgpack"
|
||||
ignore_patterns = ["*.msgpack"]
|
||||
|
||||
if custom_pipeline is not None:
|
||||
allow_patterns += [CUSTOM_PIPELINE_FILE_NAME]
|
||||
@@ -473,6 +491,15 @@ class DiffusionPipeline(ConfigMixin):
|
||||
user_agent["custom_pipeline"] = custom_pipeline
|
||||
user_agent = http_user_agent(user_agent)
|
||||
|
||||
if is_safetensors_available():
|
||||
info = model_info(
|
||||
pretrained_model_name_or_path,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
)
|
||||
if is_safetensors_compatible(info):
|
||||
ignore_patterns.append("*.bin")
|
||||
|
||||
# download all allow_patterns
|
||||
cached_folder = snapshot_download(
|
||||
pretrained_model_name_or_path,
|
||||
|
||||
@@ -28,6 +28,7 @@ from .import_utils import (
|
||||
is_inflect_available,
|
||||
is_modelcards_available,
|
||||
is_onnx_available,
|
||||
is_safetensors_available,
|
||||
is_scipy_available,
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
@@ -69,6 +70,7 @@ CONFIG_NAME = "config.json"
|
||||
WEIGHTS_NAME = "diffusion_pytorch_model.bin"
|
||||
FLAX_WEIGHTS_NAME = "diffusion_flax_model.msgpack"
|
||||
ONNX_WEIGHTS_NAME = "model.onnx"
|
||||
SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
|
||||
ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
|
||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
|
||||
DIFFUSERS_CACHE = default_cache_path
|
||||
|
||||
@@ -42,6 +42,7 @@ ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
|
||||
USE_TF = os.environ.get("USE_TF", "AUTO").upper()
|
||||
USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
|
||||
USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
|
||||
USE_SAFETENSORS = os.environ.get("USE_SAFETENSORS", "AUTO").upper()
|
||||
|
||||
STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}
|
||||
|
||||
@@ -55,7 +56,7 @@ if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VA
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_torch_available = False
|
||||
else:
|
||||
logger.info("Disabling PyTorch because USE_TF is set")
|
||||
logger.info("Disabling PyTorch because USE_TORCH is set")
|
||||
_torch_available = False
|
||||
|
||||
|
||||
@@ -109,6 +110,17 @@ if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
|
||||
else:
|
||||
_flax_available = False
|
||||
|
||||
if USE_SAFETENSORS in ENV_VARS_TRUE_AND_AUTO_VALUES:
|
||||
_safetensors_available = importlib.util.find_spec("safetensors") is not None
|
||||
if _safetensors_available:
|
||||
try:
|
||||
_safetensors_version = importlib_metadata.version("safetensors")
|
||||
logger.info(f"Safetensors version {_safetensors_version} available.")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_safetensors_available = False
|
||||
else:
|
||||
logger.info("Disabling Safetensors because USE_TF is set")
|
||||
_safetensors_available = False
|
||||
|
||||
_transformers_available = importlib.util.find_spec("transformers") is not None
|
||||
try:
|
||||
@@ -190,6 +202,10 @@ def is_torch_available():
|
||||
return _torch_available
|
||||
|
||||
|
||||
def is_safetensors_available():
|
||||
return _safetensors_available
|
||||
|
||||
|
||||
def is_tf_available():
|
||||
return _tf_available
|
||||
|
||||
|
||||
@@ -92,6 +92,24 @@ class DownloadTests(unittest.TestCase):
|
||||
# None of the downloaded files should be a flax file even if we have some here:
|
||||
# https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_flax_model.msgpack
|
||||
assert not any(f.endswith(".msgpack") for f in files)
|
||||
# We need to never convert this tiny model to safetensors for this test to pass
|
||||
assert not any(f.endswith(".safetensors") for f in files)
|
||||
|
||||
def test_download_safetensors(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
# pipeline has Flax weights
|
||||
_ = DiffusionPipeline.from_pretrained(
|
||||
"hf-internal-testing/tiny-stable-diffusion-pipe-safetensors",
|
||||
safety_checker=None,
|
||||
cache_dir=tmpdirname,
|
||||
)
|
||||
|
||||
all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))]
|
||||
files = [item for sublist in all_root_files for item in sublist]
|
||||
|
||||
# None of the downloaded files should be a pytorch file even if we have some here:
|
||||
# https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_flax_model.msgpack
|
||||
assert not any(f.endswith(".bin") for f in files)
|
||||
|
||||
def test_download_no_safety_checker(self):
|
||||
prompt = "hello"
|
||||
|
||||
Reference in New Issue
Block a user