From da857bebb604676938e141b48e8791bbc38df209 Mon Sep 17 00:00:00 2001 From: hlky Date: Wed, 2 Apr 2025 12:45:36 +0100 Subject: [PATCH] Revert `save_model` in ModelMixin save_pretrained and use safe_serialization=False in test (#11196) --- src/diffusers/models/modeling_utils.py | 5 +---- tests/pipelines/test_pipelines_common.py | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 814547d82b..19ac868cda 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -714,10 +714,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): if safe_serialization: # At some point we will need to deal better with save_function (used for TPU and other distributed # joyfulness), but for now this enough. - try: - safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"}) - except RuntimeError: - safetensors.torch.save_model(model_to_save, filepath, metadata={"format": "pt"}) + safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"}) else: torch.save(shard, filepath) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index cc5008e372..d3e39e363f 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -2293,7 +2293,7 @@ class PipelineTesterMixin: specified_key = next(iter(components.keys())) with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname: - pipe.save_pretrained(tmpdirname) + pipe.save_pretrained(tmpdirname, safe_serialization=False) torch_dtype_dict = {specified_key: torch.bfloat16, "default": torch.float16} loaded_pipe = self.pipeline_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype_dict)