mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
[Proposal] Support saving to safetensors (#1494)
* Add parameter safe_serialization to DiffusionPipeline.save_pretrained * Add option safe_serialization on ModelMixin.save_pretrained * Add test test_save_safe_serialization * Black * Re-trigger the CI * Fix doc-builder * Validate files are saved as safetensor in test_save_safe_serialization
This commit is contained in:
@@ -191,7 +191,8 @@ class ModelMixin(torch.nn.Module):
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
is_main_process: bool = True,
|
||||
save_function: Callable = torch.save,
|
||||
save_function: Callable = None,
|
||||
safe_serialization: bool = False,
|
||||
):
|
||||
"""
|
||||
Save a model and its configuration file to a directory, so that it can be re-loaded using the
|
||||
@@ -206,12 +207,21 @@ class ModelMixin(torch.nn.Module):
|
||||
the main process to avoid race conditions.
|
||||
save_function (`Callable`):
|
||||
The function to use to save the state dictionary. Useful on distributed training like TPUs when one
|
||||
need to replace `torch.save` by another method.
|
||||
need to replace `torch.save` by another method. Can be configured with the environment variable
|
||||
`DIFFUSERS_SAVE_MODE`.
|
||||
safe_serialization (`bool`, *optional*, defaults to `False`):
|
||||
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
|
||||
"""
|
||||
if safe_serialization and not is_safetensors_available():
|
||||
raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
|
||||
|
||||
if os.path.isfile(save_directory):
|
||||
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
return
|
||||
|
||||
if save_function is None:
|
||||
save_function = safetensors.torch.save_file if safe_serialization else torch.save
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
model_to_save = self
|
||||
@@ -224,18 +234,21 @@ class ModelMixin(torch.nn.Module):
|
||||
# Save the model
|
||||
state_dict = model_to_save.state_dict()
|
||||
|
||||
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
|
||||
|
||||
# Clean the folder from a previous save
|
||||
for filename in os.listdir(save_directory):
|
||||
full_filename = os.path.join(save_directory, filename)
|
||||
# If we have a shard file that is not going to be replaced, we delete it, but only from the main process
|
||||
# in distributed settings to avoid race conditions.
|
||||
if filename.startswith(WEIGHTS_NAME[:-4]) and os.path.isfile(full_filename) and is_main_process:
|
||||
weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
|
||||
if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and is_main_process:
|
||||
os.remove(full_filename)
|
||||
|
||||
# Save the model
|
||||
save_function(state_dict, os.path.join(save_directory, WEIGHTS_NAME))
|
||||
save_function(state_dict, os.path.join(save_directory, weights_name))
|
||||
|
||||
logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}")
|
||||
logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
||||
|
||||
@@ -188,7 +188,11 @@ class DiffusionPipeline(ConfigMixin):
|
||||
# set models
|
||||
setattr(self, name, module)
|
||||
|
||||
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
safe_serialization: bool = False,
|
||||
):
|
||||
"""
|
||||
Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to
|
||||
a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading
|
||||
@@ -197,6 +201,8 @@ class DiffusionPipeline(ConfigMixin):
|
||||
Arguments:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Directory to which to save. Will be created if it doesn't exist.
|
||||
safe_serialization (`bool`, *optional*, defaults to `False`):
|
||||
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
|
||||
"""
|
||||
self.save_config(save_directory)
|
||||
|
||||
@@ -234,7 +240,16 @@ class DiffusionPipeline(ConfigMixin):
|
||||
break
|
||||
|
||||
save_method = getattr(sub_model, save_method_name)
|
||||
save_method(os.path.join(save_directory, pipeline_component_name))
|
||||
|
||||
# Call the save method with the argument safe_serialization only if it's supported
|
||||
save_method_signature = inspect.signature(save_method)
|
||||
save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
|
||||
if save_method_accept_safe:
|
||||
save_method(
|
||||
os.path.join(save_directory, pipeline_component_name), safe_serialization=safe_serialization
|
||||
)
|
||||
else:
|
||||
save_method(os.path.join(save_directory, pipeline_component_name))
|
||||
|
||||
def to(self, torch_device: Optional[Union[str, torch.device]] = None):
|
||||
if torch_device is None:
|
||||
|
||||
@@ -25,6 +25,8 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
import PIL
|
||||
import safetensors.torch
|
||||
import transformers
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDIMPipeline,
|
||||
@@ -537,6 +539,34 @@ class PipelineFastTests(unittest.TestCase):
|
||||
|
||||
assert dict(ddim_config) == dict(ddim_config_2)
|
||||
|
||||
def test_save_safe_serialization(self):
|
||||
pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pipeline.save_pretrained(tmpdirname, safe_serialization=True)
|
||||
|
||||
# Validate that the VAE safetensor exists and are of the correct format
|
||||
vae_path = os.path.join(tmpdirname, "vae", "diffusion_pytorch_model.safetensors")
|
||||
assert os.path.exists(vae_path), f"Could not find {vae_path}"
|
||||
_ = safetensors.torch.load_file(vae_path)
|
||||
|
||||
# Validate that the UNet safetensor exists and are of the correct format
|
||||
unet_path = os.path.join(tmpdirname, "unet", "diffusion_pytorch_model.safetensors")
|
||||
assert os.path.exists(unet_path), f"Could not find {unet_path}"
|
||||
_ = safetensors.torch.load_file(unet_path)
|
||||
|
||||
# Validate that the text encoder safetensor exists and are of the correct format
|
||||
text_encoder_path = os.path.join(tmpdirname, "text_encoder", "model.safetensors")
|
||||
if transformers.__version__ >= "4.25.1":
|
||||
assert os.path.exists(text_encoder_path), f"Could not find {text_encoder_path}"
|
||||
_ = safetensors.torch.load_file(text_encoder_path)
|
||||
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(tmpdirname)
|
||||
assert pipeline.unet is not None
|
||||
assert pipeline.vae is not None
|
||||
assert pipeline.text_encoder is not None
|
||||
assert pipeline.scheduler is not None
|
||||
assert pipeline.feature_extractor is not None
|
||||
|
||||
def test_optional_components(self):
|
||||
unet = self.dummy_cond_unet()
|
||||
pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
|
||||
|
||||
Reference in New Issue
Block a user