mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Revert save_model in ModelMixin save_pretrained and use safe_serialization=False in test (#11196)
This commit is contained in:
@@ -714,10 +714,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
if safe_serialization:
|
||||
# At some point we will need to deal better with save_function (used for TPU and other distributed
|
||||
# joyfulness), but for now this enough.
|
||||
try:
|
||||
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
|
||||
except RuntimeError:
|
||||
safetensors.torch.save_model(model_to_save, filepath, metadata={"format": "pt"})
|
||||
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
|
||||
else:
|
||||
torch.save(shard, filepath)
|
||||
|
||||
|
||||
@@ -2293,7 +2293,7 @@ class PipelineTesterMixin:
|
||||
specified_key = next(iter(components.keys()))
|
||||
|
||||
with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
|
||||
pipe.save_pretrained(tmpdirname)
|
||||
pipe.save_pretrained(tmpdirname, safe_serialization=False)
|
||||
torch_dtype_dict = {specified_key: torch.bfloat16, "default": torch.float16}
|
||||
loaded_pipe = self.pipeline_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype_dict)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user