1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

fix a regression in is_safetensors_compatible (#9234)

fix
This commit is contained in:
YiYi Xu
2024-08-21 03:26:55 -10:00
committed by GitHub
parent 867e0c919e
commit 214372aa99
3 changed files with 33 additions and 3 deletions

View File

@@ -89,7 +89,7 @@ for library in LOADABLE_CLASSES:
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
def is_safetensors_compatible(filenames, passed_components=None) -> bool:
def is_safetensors_compatible(filenames, passed_components=None, folder_names=None) -> bool:
"""
Checking for safetensors compatibility:
- The model is safetensors compatible only if there is a safetensors file for each model component present in
@@ -101,6 +101,8 @@ def is_safetensors_compatible(filenames, passed_components=None) -> bool:
extension is replaced with ".safetensors"
"""
passed_components = passed_components or []
if folder_names is not None:
filenames = {f for f in filenames if os.path.split(f)[0] in folder_names}
# extract all components of the pipeline and their associated files
components = {}

View File

@@ -1416,14 +1416,18 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
if (
use_safetensors
and not allow_pickle
and not is_safetensors_compatible(model_filenames, passed_components=passed_components)
and not is_safetensors_compatible(
model_filenames, passed_components=passed_components, folder_names=model_folder_names
)
):
raise EnvironmentError(
f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
)
if from_flax:
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
elif use_safetensors and is_safetensors_compatible(model_filenames, passed_components=passed_components):
elif use_safetensors and is_safetensors_compatible(
model_filenames, passed_components=passed_components, folder_names=model_folder_names
):
ignore_patterns = ["*.bin", "*.msgpack"]
use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx

View File

@@ -116,6 +116,30 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
]
self.assertFalse(is_safetensors_compatible(filenames))
def test_transformer_model_is_compatible_variant_extra_folder(self):
filenames = [
"safety_checker/pytorch_model.fp16.bin",
"safety_checker/model.fp16.safetensors",
"vae/diffusion_pytorch_model.fp16.bin",
"vae/diffusion_pytorch_model.fp16.safetensors",
"text_encoder/pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.safetensors",
]
self.assertTrue(is_safetensors_compatible(filenames, folder_names={"vae", "unet"}))
def test_transformer_model_is_not_compatible_variant_extra_folder(self):
filenames = [
"safety_checker/pytorch_model.fp16.bin",
"safety_checker/model.fp16.safetensors",
"vae/diffusion_pytorch_model.fp16.bin",
"vae/diffusion_pytorch_model.fp16.safetensors",
"text_encoder/pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.safetensors",
]
self.assertFalse(is_safetensors_compatible(filenames, folder_names={"text_encoder"}))
def test_transformers_is_compatible_sharded(self):
filenames = [
"text_encoder/pytorch_model.bin",