1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

[Proposal] Support saving to safetensors (#1494)

* Add parameter safe_serialization to DiffusionPipeline.save_pretrained

* Add option safe_serialization on ModelMixin.save_pretrained

* Add test test_save_safe_serialization

* Black

* Re-trigger the CI

* Fix doc-builder

* Validate files are saved as safetensor in test_save_safe_serialization
This commit is contained in:
Matthieu Bizien
2022-12-02 18:33:16 +01:00
committed by GitHub
parent cf4664e885
commit ae368e42d2
3 changed files with 65 additions and 7 deletions

View File

@@ -191,7 +191,8 @@ class ModelMixin(torch.nn.Module):
self,
save_directory: Union[str, os.PathLike],
is_main_process: bool = True,
save_function: Callable = torch.save,
save_function: Callable = None,
safe_serialization: bool = False,
):
"""
Save a model and its configuration file to a directory, so that it can be re-loaded using the
@@ -206,12 +207,21 @@ class ModelMixin(torch.nn.Module):
the main process to avoid race conditions.
save_function (`Callable`):
The function to use to save the state dictionary. Useful on distributed training like TPUs when one
need to replace `torch.save` by another method.
need to replace `torch.save` by another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `False`):
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
"""
if safe_serialization and not is_safetensors_available():
raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
if save_function is None:
save_function = safetensors.torch.save_file if safe_serialization else torch.save
os.makedirs(save_directory, exist_ok=True)
model_to_save = self
@@ -224,18 +234,21 @@ class ModelMixin(torch.nn.Module):
# Save the model
state_dict = model_to_save.state_dict()
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
# Clean the folder from a previous save
for filename in os.listdir(save_directory):
full_filename = os.path.join(save_directory, filename)
# If we have a shard file that is not going to be replaced, we delete it, but only from the main process
# in distributed settings to avoid race conditions.
if filename.startswith(WEIGHTS_NAME[:-4]) and os.path.isfile(full_filename) and is_main_process:
weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and is_main_process:
os.remove(full_filename)
# Save the model
save_function(state_dict, os.path.join(save_directory, WEIGHTS_NAME))
save_function(state_dict, os.path.join(save_directory, weights_name))
logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}")
logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):

View File

@@ -188,7 +188,11 @@ class DiffusionPipeline(ConfigMixin):
# set models
setattr(self, name, module)
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
safe_serialization: bool = False,
):
"""
Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to
a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading
@@ -197,6 +201,8 @@ class DiffusionPipeline(ConfigMixin):
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to which to save. Will be created if it doesn't exist.
safe_serialization (`bool`, *optional*, defaults to `False`):
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
"""
self.save_config(save_directory)
@@ -234,7 +240,16 @@ class DiffusionPipeline(ConfigMixin):
break
save_method = getattr(sub_model, save_method_name)
save_method(os.path.join(save_directory, pipeline_component_name))
# Call the save method with the argument safe_serialization only if it's supported
save_method_signature = inspect.signature(save_method)
save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
if save_method_accept_safe:
save_method(
os.path.join(save_directory, pipeline_component_name), safe_serialization=safe_serialization
)
else:
save_method(os.path.join(save_directory, pipeline_component_name))
def to(self, torch_device: Optional[Union[str, torch.device]] = None):
if torch_device is None:

View File

@@ -25,6 +25,8 @@ import numpy as np
import torch
import PIL
import safetensors.torch
import transformers
from diffusers import (
AutoencoderKL,
DDIMPipeline,
@@ -537,6 +539,34 @@ class PipelineFastTests(unittest.TestCase):
assert dict(ddim_config) == dict(ddim_config_2)
def test_save_safe_serialization(self):
pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")
with tempfile.TemporaryDirectory() as tmpdirname:
pipeline.save_pretrained(tmpdirname, safe_serialization=True)
# Validate that the VAE safetensor exists and are of the correct format
vae_path = os.path.join(tmpdirname, "vae", "diffusion_pytorch_model.safetensors")
assert os.path.exists(vae_path), f"Could not find {vae_path}"
_ = safetensors.torch.load_file(vae_path)
# Validate that the UNet safetensor exists and are of the correct format
unet_path = os.path.join(tmpdirname, "unet", "diffusion_pytorch_model.safetensors")
assert os.path.exists(unet_path), f"Could not find {unet_path}"
_ = safetensors.torch.load_file(unet_path)
# Validate that the text encoder safetensor exists and are of the correct format
text_encoder_path = os.path.join(tmpdirname, "text_encoder", "model.safetensors")
if transformers.__version__ >= "4.25.1":
assert os.path.exists(text_encoder_path), f"Could not find {text_encoder_path}"
_ = safetensors.torch.load_file(text_encoder_path)
pipeline = StableDiffusionPipeline.from_pretrained(tmpdirname)
assert pipeline.unet is not None
assert pipeline.vae is not None
assert pipeline.text_encoder is not None
assert pipeline.scheduler is not None
assert pipeline.feature_extractor is not None
def test_optional_components(self):
unet = self.dummy_cond_unet()
pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")