From ae368e42d2c6ec5ef03b7a40338a418f186799eb Mon Sep 17 00:00:00 2001
From: Matthieu Bizien
Date: Fri, 2 Dec 2022 18:33:16 +0100
Subject: [PATCH] [Proposal] Support saving to safetensors (#1494)

* Add parameter safe_serialization to DiffusionPipeline.save_pretrained

* Add option safe_serialization on ModelMixin.save_pretrained

* Add test test_save_safe_serialization

* Black

* Re-trigger the CI

* Fix doc-builder

* Validate files are saved as safetensor in test_save_safe_serialization
---
 src/diffusers/modeling_utils.py | 23 ++++++++++++++++++-----
 src/diffusers/pipeline_utils.py | 19 +++++++++++++++++--
 tests/test_pipelines.py         | 30 ++++++++++++++++++++++++++++++
 3 files changed, 65 insertions(+), 7 deletions(-)

diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py
index 8f0222957a..e270f75e05 100644
--- a/src/diffusers/modeling_utils.py
+++ b/src/diffusers/modeling_utils.py
@@ -191,7 +191,8 @@ class ModelMixin(torch.nn.Module):
         self,
         save_directory: Union[str, os.PathLike],
         is_main_process: bool = True,
-        save_function: Callable = torch.save,
+        save_function: Callable = None,
+        safe_serialization: bool = False,
     ):
         """
         Save a model and its configuration file to a directory, so that it can be re-loaded using the
@@ -206,12 +207,21 @@ class ModelMixin(torch.nn.Module):
                 the main process to avoid race conditions.
             save_function (`Callable`):
                 The function to use to save the state dictionary. Useful on distributed training like TPUs when one
-                need to replace `torch.save` by another method.
+                need to replace `torch.save` by another method. Can be configured with the environment variable
+                `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `False`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
         """
+        if safe_serialization and not is_safetensors_available():
+            raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
+
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
 
+        if save_function is None:
+            save_function = safetensors.torch.save_file if safe_serialization else torch.save
+
         os.makedirs(save_directory, exist_ok=True)
 
         model_to_save = self
@@ -224,18 +234,21 @@ class ModelMixin(torch.nn.Module):
         # Save the model
         state_dict = model_to_save.state_dict()
 
+        weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
+        weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
+
         # Clean the folder from a previous save
         for filename in os.listdir(save_directory):
             full_filename = os.path.join(save_directory, filename)
             # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
             # in distributed settings to avoid race conditions.
-            if filename.startswith(WEIGHTS_NAME[:-4]) and os.path.isfile(full_filename) and is_main_process:
+            if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and is_main_process:
                 os.remove(full_filename)
 
         # Save the model
-        save_function(state_dict, os.path.join(save_directory, WEIGHTS_NAME))
+        save_function(state_dict, os.path.join(save_directory, weights_name))
 
-        logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}")
+        logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py
index 056ca4fa73..e65d55e20c 100644
--- a/src/diffusers/pipeline_utils.py
+++ b/src/diffusers/pipeline_utils.py
@@ -188,7 +188,11 @@ class DiffusionPipeline(ConfigMixin):
             # set models
             setattr(self, name, module)
 
-    def save_pretrained(self, save_directory: Union[str, os.PathLike]):
+    def save_pretrained(
+        self,
+        save_directory: Union[str, os.PathLike],
+        safe_serialization: bool = False,
+    ):
         """
         Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to
         a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading
@@ -197,6 +201,8 @@ class DiffusionPipeline(ConfigMixin):
         Arguments:
             save_directory (`str` or `os.PathLike`):
                 Directory to which to save. Will be created if it doesn't exist.
+            safe_serialization (`bool`, *optional*, defaults to `False`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
         """
         self.save_config(save_directory)
 
@@ -234,7 +240,16 @@ class DiffusionPipeline(ConfigMixin):
                     break
 
             save_method = getattr(sub_model, save_method_name)
-            save_method(os.path.join(save_directory, pipeline_component_name))
+
+            # Call the save method with the argument safe_serialization only if it's supported
+            save_method_signature = inspect.signature(save_method)
+            save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
+            if save_method_accept_safe:
+                save_method(
+                    os.path.join(save_directory, pipeline_component_name), safe_serialization=safe_serialization
+                )
+            else:
+                save_method(os.path.join(save_directory, pipeline_component_name))
 
     def to(self, torch_device: Optional[Union[str, torch.device]] = None):
         if torch_device is None:
diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index 630cec65ff..072fe2fe76 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -25,6 +25,8 @@ import numpy as np
 import torch
 
 import PIL
+import safetensors.torch
+import transformers
 from diffusers import (
     AutoencoderKL,
     DDIMPipeline,
@@ -537,6 +539,34 @@ class PipelineFastTests(unittest.TestCase):
 
         assert dict(ddim_config) == dict(ddim_config_2)
 
+    def test_save_safe_serialization(self):
+        pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            pipeline.save_pretrained(tmpdirname, safe_serialization=True)
+
+            # Validate that the VAE safetensors file exists and is of the correct format
+            vae_path = os.path.join(tmpdirname, "vae", "diffusion_pytorch_model.safetensors")
+            assert os.path.exists(vae_path), f"Could not find {vae_path}"
+            _ = safetensors.torch.load_file(vae_path)
+
+            # Validate that the UNet safetensors file exists and is of the correct format
+            unet_path = os.path.join(tmpdirname, "unet", "diffusion_pytorch_model.safetensors")
+            assert os.path.exists(unet_path), f"Could not find {unet_path}"
+            _ = safetensors.torch.load_file(unet_path)
+
+            # Validate that the text encoder safetensors file exists and is of the correct format
+            text_encoder_path = os.path.join(tmpdirname, "text_encoder", "model.safetensors")
+            if tuple(int(p) for p in transformers.__version__.split(".")[:3] if p.isdigit()) >= (4, 25, 1):
+                assert os.path.exists(text_encoder_path), f"Could not find {text_encoder_path}"
+                _ = safetensors.torch.load_file(text_encoder_path)
+
+            pipeline = StableDiffusionPipeline.from_pretrained(tmpdirname)
+            assert pipeline.unet is not None
+            assert pipeline.vae is not None
+            assert pipeline.text_encoder is not None
+            assert pipeline.scheduler is not None
+            assert pipeline.feature_extractor is not None
+
     def test_optional_components(self):
         unet = self.dummy_cond_unet()
         pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")