1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Fix bug when variant and safetensors file do not match (#11587)

* Apply style fixes

* init test

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* adjust

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* add the variant check when there are no component folders

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* update related test cases

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* update related unit test cases

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* adjust

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* Apply style fixes

---------

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
kaixuanliu
2025-05-26 16:48:41 +08:00
committed by GitHub
parent 7ae546f8d1
commit b5c2050a16
3 changed files with 100 additions and 49 deletions

View File

@@ -92,7 +92,7 @@ for library in LOADABLE_CLASSES:
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
def is_safetensors_compatible(filenames, passed_components=None, folder_names=None) -> bool:
def is_safetensors_compatible(filenames, passed_components=None, folder_names=None, variant=None) -> bool:
"""
Checking for safetensors compatibility:
- The model is safetensors compatible only if there is a safetensors file for each model component present in
@@ -103,6 +103,31 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
- For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
extension is replaced with ".safetensors"
"""
weight_names = [
WEIGHTS_NAME,
SAFETENSORS_WEIGHTS_NAME,
FLAX_WEIGHTS_NAME,
ONNX_WEIGHTS_NAME,
ONNX_EXTERNAL_WEIGHTS_NAME,
]
if is_transformers_available():
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
# model_pytorch, diffusion_model_pytorch, ...
weight_prefixes = [w.split(".")[0] for w in weight_names]
# .bin, .safetensors, ...
weight_suffixs = [w.split(".")[-1] for w in weight_names]
# -00001-of-00002
transformers_index_format = r"\d{5}-of-\d{5}"
# `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
variant_file_re = re.compile(
rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
)
non_variant_file_re = re.compile(
rf"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$"
)
passed_components = passed_components or []
if folder_names:
filenames = {f for f in filenames if os.path.split(f)[0] in folder_names}
@@ -122,14 +147,22 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
# If there are no component folders check the main directory for safetensors files
if not components:
return any(".safetensors" in filename for filename in filenames)
if variant is not None:
filtered_filenames = filter_with_regex(filenames, variant_file_re)
else:
filtered_filenames = filter_with_regex(filenames, non_variant_file_re)
return any(".safetensors" in filename for filename in filtered_filenames)
# iterate over all files of a component
# check if safetensor files exist for that component
# if variant is provided check if the variant of the safetensors exists
for component, component_filenames in components.items():
matches = []
for component_filename in component_filenames:
if variant is not None:
filtered_component_filenames = filter_with_regex(component_filenames, variant_file_re)
else:
filtered_component_filenames = filter_with_regex(component_filenames, non_variant_file_re)
for component_filename in filtered_component_filenames:
filename, extension = os.path.splitext(component_filename)
match_exists = extension == ".safetensors"
@@ -159,6 +192,10 @@ def filter_model_files(filenames):
return [f for f in filenames if any(f.endswith(extension) for extension in allowed_extensions)]
def filter_with_regex(filenames, pattern_re):
return {f for f in filenames if pattern_re.match(f.split("/")[-1]) is not None}
def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -> Union[List[os.PathLike], str]:
weight_names = [
WEIGHTS_NAME,
@@ -207,9 +244,6 @@ def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -
# interested in the extension name
return {f for f in filenames if not any(f.endswith(pat.lstrip("*.")) for pat in ignore_patterns)}
def filter_with_regex(filenames, pattern_re):
return {f for f in filenames if pattern_re.match(f.split("/")[-1]) is not None}
# Group files by component
components = {}
for filename in filenames:
@@ -997,7 +1031,7 @@ def _get_ignore_patterns(
use_safetensors
and not allow_pickle
and not is_safetensors_compatible(
model_filenames, passed_components=passed_components, folder_names=model_folder_names
model_filenames, passed_components=passed_components, folder_names=model_folder_names, variant=variant
)
):
raise EnvironmentError(
@@ -1008,7 +1042,7 @@ def _get_ignore_patterns(
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
elif use_safetensors and is_safetensors_compatible(
model_filenames, passed_components=passed_components, folder_names=model_folder_names
model_filenames, passed_components=passed_components, folder_names=model_folder_names, variant=variant
):
ignore_patterns = ["*.bin", "*.msgpack"]

View File

@@ -87,21 +87,24 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
"unet/diffusion_pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.safetensors",
]
self.assertTrue(is_safetensors_compatible(filenames))
self.assertFalse(is_safetensors_compatible(filenames))
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
def test_diffusers_model_is_compatible_variant(self):
filenames = [
"unet/diffusion_pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.safetensors",
]
self.assertTrue(is_safetensors_compatible(filenames))
self.assertFalse(is_safetensors_compatible(filenames))
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
def test_diffusers_model_is_compatible_variant_mixed(self):
filenames = [
"unet/diffusion_pytorch_model.bin",
"unet/diffusion_pytorch_model.fp16.safetensors",
]
self.assertTrue(is_safetensors_compatible(filenames))
self.assertFalse(is_safetensors_compatible(filenames))
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
def test_diffusers_model_is_not_compatible_variant(self):
filenames = [
@@ -121,7 +124,8 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
"text_encoder/pytorch_model.fp16.bin",
"text_encoder/model.fp16.safetensors",
]
self.assertTrue(is_safetensors_compatible(filenames))
self.assertFalse(is_safetensors_compatible(filenames))
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
def test_transformer_model_is_not_compatible_variant(self):
filenames = [
@@ -145,7 +149,8 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
"unet/diffusion_pytorch_model.fp16.bin",
"unet/diffusion_pytorch_model.fp16.safetensors",
]
self.assertTrue(is_safetensors_compatible(filenames, folder_names={"vae", "unet"}))
self.assertFalse(is_safetensors_compatible(filenames, folder_names={"vae", "unet"}))
self.assertTrue(is_safetensors_compatible(filenames, folder_names={"vae", "unet"}, variant="fp16"))
def test_transformer_model_is_not_compatible_variant_extra_folder(self):
filenames = [
@@ -173,7 +178,8 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
"text_encoder/model.fp16-00001-of-00002.safetensors",
"text_encoder/model.fp16-00001-of-00002.safetensors",
]
self.assertTrue(is_safetensors_compatible(filenames))
self.assertFalse(is_safetensors_compatible(filenames))
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
def test_diffusers_is_compatible_sharded(self):
filenames = [
@@ -189,13 +195,15 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
"unet/diffusion_pytorch_model.fp16-00001-of-00002.safetensors",
"unet/diffusion_pytorch_model.fp16-00001-of-00002.safetensors",
]
self.assertTrue(is_safetensors_compatible(filenames))
self.assertFalse(is_safetensors_compatible(filenames))
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
def test_diffusers_is_compatible_only_variants(self):
filenames = [
"unet/diffusion_pytorch_model.fp16.safetensors",
]
self.assertTrue(is_safetensors_compatible(filenames))
self.assertFalse(is_safetensors_compatible(filenames))
self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))
def test_diffusers_is_compatible_no_components(self):
filenames = [

View File

@@ -538,26 +538,38 @@ class DownloadTests(unittest.TestCase):
variant = "no_ema"
with tempfile.TemporaryDirectory() as tmpdirname:
tmpdirname = StableDiffusionPipeline.download(
"hf-internal-testing/stable-diffusion-all-variants",
cache_dir=tmpdirname,
variant=variant,
use_safetensors=use_safetensors,
)
all_root_files = [t[-1] for t in os.walk(tmpdirname)]
files = [item for sublist in all_root_files for item in sublist]
if use_safetensors:
with self.assertRaises(OSError) as error_context:
tmpdirname = StableDiffusionPipeline.download(
"hf-internal-testing/stable-diffusion-all-variants",
cache_dir=tmpdirname,
variant=variant,
use_safetensors=use_safetensors,
)
assert "Could not find the necessary `safetensors` weights" in str(error_context.exception)
else:
tmpdirname = StableDiffusionPipeline.download(
"hf-internal-testing/stable-diffusion-all-variants",
cache_dir=tmpdirname,
variant=variant,
use_safetensors=use_safetensors,
)
all_root_files = [t[-1] for t in os.walk(tmpdirname)]
files = [item for sublist in all_root_files for item in sublist]
unet_files = os.listdir(os.path.join(tmpdirname, "unet"))
unet_files = os.listdir(os.path.join(tmpdirname, "unet"))
# Some of the downloaded files should be a non-variant file, check:
# https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
# only unet has "no_ema" variant
assert f"diffusion_pytorch_model.{variant}{this_format}" in unet_files
assert len([f for f in files if f.endswith(f"{variant}{this_format}")]) == 1
# vae, safety_checker and text_encoder should have no variant
assert sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3
assert not any(f.endswith(other_format) for f in files)
# Some of the downloaded files should be a non-variant file, check:
# https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
# only unet has "no_ema" variant
assert f"diffusion_pytorch_model.{variant}{this_format}" in unet_files
assert len([f for f in files if f.endswith(f"{variant}{this_format}")]) == 1
# vae, safety_checker and text_encoder should have no variant
assert (
sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3
)
assert not any(f.endswith(other_format) for f in files)
def test_download_variants_with_sharded_checkpoints(self):
# Here we test for downloading of "variant" files belonging to the `unet` and
@@ -588,20 +600,17 @@ class DownloadTests(unittest.TestCase):
logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
deprecated_warning_msg = "Warning: The repository contains sharded checkpoints for variant"
for is_local in [True, False]:
with CaptureLogger(logger) as cap_logger:
with tempfile.TemporaryDirectory() as tmpdirname:
local_repo_id = repo_id
if is_local:
local_repo_id = snapshot_download(repo_id, cache_dir=tmpdirname)
with CaptureLogger(logger) as cap_logger:
with tempfile.TemporaryDirectory() as tmpdirname:
local_repo_id = snapshot_download(repo_id, cache_dir=tmpdirname)
_ = DiffusionPipeline.from_pretrained(
local_repo_id,
safety_checker=None,
variant="fp16",
use_safetensors=True,
)
assert deprecated_warning_msg in str(cap_logger), "Deprecation warning not found in logs"
_ = DiffusionPipeline.from_pretrained(
local_repo_id,
safety_checker=None,
variant="fp16",
use_safetensors=True,
)
assert deprecated_warning_msg in str(cap_logger), "Deprecation warning not found in logs"
def test_download_safetensors_only_variant_exists_for_model(self):
variant = None
@@ -616,7 +625,7 @@ class DownloadTests(unittest.TestCase):
variant=variant,
use_safetensors=use_safetensors,
)
assert "Error no file name" in str(error_context.exception)
assert "Could not find the necessary `safetensors` weights" in str(error_context.exception)
# text encoder has fp16 variants so we can load it
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -675,7 +684,7 @@ class DownloadTests(unittest.TestCase):
use_safetensors=use_safetensors,
)
assert "Error no file name" in str(error_context.exception)
assert "Could not find the necessary `safetensors` weights" in str(error_context.exception)
def test_download_bin_variant_does_not_exist_for_model(self):
variant = "no_ema"