mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
@@ -89,7 +89,7 @@ for library in LOADABLE_CLASSES:
|
||||
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
|
||||
|
||||
|
||||
def is_safetensors_compatible(filenames, passed_components=None) -> bool:
|
||||
def is_safetensors_compatible(filenames, passed_components=None, folder_names=None) -> bool:
|
||||
"""
|
||||
Checking for safetensors compatibility:
|
||||
- The model is safetensors compatible only if there is a safetensors file for each model component present in
|
||||
@@ -101,6 +101,8 @@ def is_safetensors_compatible(filenames, passed_components=None) -> bool:
|
||||
extension is replaced with ".safetensors"
|
||||
"""
|
||||
passed_components = passed_components or []
|
||||
if folder_names is not None:
|
||||
filenames = {f for f in filenames if os.path.split(f)[0] in folder_names}
|
||||
|
||||
# extract all components of the pipeline and their associated files
|
||||
components = {}
|
||||
|
||||
@@ -1416,14 +1416,18 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
if (
|
||||
use_safetensors
|
||||
and not allow_pickle
|
||||
and not is_safetensors_compatible(model_filenames, passed_components=passed_components)
|
||||
and not is_safetensors_compatible(
|
||||
model_filenames, passed_components=passed_components, folder_names=model_folder_names
|
||||
)
|
||||
):
|
||||
raise EnvironmentError(
|
||||
f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
|
||||
)
|
||||
if from_flax:
|
||||
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
|
||||
elif use_safetensors and is_safetensors_compatible(model_filenames, passed_components=passed_components):
|
||||
elif use_safetensors and is_safetensors_compatible(
|
||||
model_filenames, passed_components=passed_components, folder_names=model_folder_names
|
||||
):
|
||||
ignore_patterns = ["*.bin", "*.msgpack"]
|
||||
|
||||
use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
|
||||
|
||||
@@ -116,6 +116,30 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
|
||||
]
|
||||
self.assertFalse(is_safetensors_compatible(filenames))
|
||||
|
||||
def test_transformer_model_is_compatible_variant_extra_folder(self):
|
||||
filenames = [
|
||||
"safety_checker/pytorch_model.fp16.bin",
|
||||
"safety_checker/model.fp16.safetensors",
|
||||
"vae/diffusion_pytorch_model.fp16.bin",
|
||||
"vae/diffusion_pytorch_model.fp16.safetensors",
|
||||
"text_encoder/pytorch_model.fp16.bin",
|
||||
"unet/diffusion_pytorch_model.fp16.bin",
|
||||
"unet/diffusion_pytorch_model.fp16.safetensors",
|
||||
]
|
||||
self.assertTrue(is_safetensors_compatible(filenames, folder_names={"vae", "unet"}))
|
||||
|
||||
def test_transformer_model_is_not_compatible_variant_extra_folder(self):
|
||||
filenames = [
|
||||
"safety_checker/pytorch_model.fp16.bin",
|
||||
"safety_checker/model.fp16.safetensors",
|
||||
"vae/diffusion_pytorch_model.fp16.bin",
|
||||
"vae/diffusion_pytorch_model.fp16.safetensors",
|
||||
"text_encoder/pytorch_model.fp16.bin",
|
||||
"unet/diffusion_pytorch_model.fp16.bin",
|
||||
"unet/diffusion_pytorch_model.fp16.safetensors",
|
||||
]
|
||||
self.assertFalse(is_safetensors_compatible(filenames, folder_names={"text_encoder"}))
|
||||
|
||||
def test_transformers_is_compatible_sharded(self):
|
||||
filenames = [
|
||||
"text_encoder/pytorch_model.bin",
|
||||
|
||||
Reference in New Issue
Block a user