Fix bug when variant and safetensor file does not match (#11587)
* Apply style fixes
* init test
  Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
* adjust
  Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
* add the variant check when there are no component folders
  Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
* update related test cases
  Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
* update related unit test cases
  Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
* adjust
  Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
* Apply style fixes

---------

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
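In short, the change makes the safetensors compatibility check variant-aware. A minimal sketch of the intended behavior, mirroring the updated unit tests further below (the import path is an assumption about where the helper lives):

from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible

# A repo layout where only fp16-variant weights exist for the unet component.
filenames = [
    "unet/diffusion_pytorch_model.fp16.bin",
    "unet/diffusion_pytorch_model.fp16.safetensors",
]

# Without a requested variant there is no matching non-variant safetensors file,
# so the check is expected to fail ...
assert not is_safetensors_compatible(filenames)
# ... while asking explicitly for the fp16 variant is expected to pass.
assert is_safetensors_compatible(filenames, variant="fp16")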
@@ -92,7 +92,7 @@ for library in LOADABLE_CLASSES:
     ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])


-def is_safetensors_compatible(filenames, passed_components=None, folder_names=None) -> bool:
+def is_safetensors_compatible(filenames, passed_components=None, folder_names=None, variant=None) -> bool:
     """
     Checking for safetensors compatibility:
     - The model is safetensors compatible only if there is a safetensors file for each model component present in
@@ -103,6 +103,31 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
     - For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
       extension is replaced with ".safetensors"
     """
+    weight_names = [
+        WEIGHTS_NAME,
+        SAFETENSORS_WEIGHTS_NAME,
+        FLAX_WEIGHTS_NAME,
+        ONNX_WEIGHTS_NAME,
+        ONNX_EXTERNAL_WEIGHTS_NAME,
+    ]
+
+    if is_transformers_available():
+        weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
+
+    # model_pytorch, diffusion_model_pytorch, ...
+    weight_prefixes = [w.split(".")[0] for w in weight_names]
+    # .bin, .safetensors, ...
+    weight_suffixs = [w.split(".")[-1] for w in weight_names]
+    # -00001-of-00002
+    transformers_index_format = r"\d{5}-of-\d{5}"
+    # `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
+    variant_file_re = re.compile(
+        rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
+    )
+    non_variant_file_re = re.compile(
+        rf"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$"
+    )
+
     passed_components = passed_components or []
     if folder_names:
         filenames = {f for f in filenames if os.path.split(f)[0] in folder_names}
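To make the two patterns above concrete, here is a self-contained sketch of how they classify filenames; the prefix/suffix lists are illustrative stand-ins for the real constants (WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME, ...) imported by this module:

import re

# Stand-in values for illustration only.
weight_prefixes = ["pytorch_model", "diffusion_pytorch_model", "model"]
weight_suffixs = ["bin", "safetensors", "msgpack", "onnx", "pb"]
transformers_index_format = r"\d{5}-of-\d{5}"
variant = "fp16"

variant_file_re = re.compile(
    rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
)
non_variant_file_re = re.compile(
    rf"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$"
)

# Variant files (plain or sharded) match only the variant pattern.
assert variant_file_re.match("diffusion_pytorch_model.fp16.safetensors")
assert variant_file_re.match("model.fp16-00001-of-00002.safetensors")
assert not variant_file_re.match("diffusion_pytorch_model.safetensors")

# Non-variant files (plain or sharded) match only the non-variant pattern.
assert non_variant_file_re.match("diffusion_pytorch_model.safetensors")
assert non_variant_file_re.match("model-00001-of-00002.safetensors")
assert not non_variant_file_re.match("diffusion_pytorch_model.fp16.safetensors")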
@@ -122,14 +147,22 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No

     # If there are no component folders check the main directory for safetensors files
     if not components:
-        return any(".safetensors" in filename for filename in filenames)
+        if variant is not None:
+            filtered_filenames = filter_with_regex(filenames, variant_file_re)
+        else:
+            filtered_filenames = filter_with_regex(filenames, non_variant_file_re)
+        return any(".safetensors" in filename for filename in filtered_filenames)

     # iterate over all files of a component
     # check if safetensor files exist for that component
     # if variant is provided check if the variant of the safetensors exists
     for component, component_filenames in components.items():
         matches = []
-        for component_filename in component_filenames:
+        if variant is not None:
+            filtered_component_filenames = filter_with_regex(component_filenames, variant_file_re)
+        else:
+            filtered_component_filenames = filter_with_regex(component_filenames, non_variant_file_re)
+        for component_filename in filtered_component_filenames:
             filename, extension = os.path.splitext(component_filename)

             match_exists = extension == ".safetensors"
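The first branch above is the new handling for repositories whose weights sit at the top level rather than in component folders; a rough sketch of its effect (same import-path assumption as before):

from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible

# Single-model repo: the only safetensors file at the root is an fp16 variant.
root_level_files = ["diffusion_pytorch_model.fp16.safetensors"]

# No component folders and no variant requested: the non-variant regex filters
# everything out, so the repo should not be reported as safetensors compatible.
assert not is_safetensors_compatible(root_level_files)
# Requesting the matching variant keeps the file and the check should pass.
assert is_safetensors_compatible(root_level_files, variant="fp16")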
@@ -159,6 +192,10 @@ def filter_model_files(filenames):
     return [f for f in filenames if any(f.endswith(extension) for extension in allowed_extensions)]


+def filter_with_regex(filenames, pattern_re):
+    return {f for f in filenames if pattern_re.match(f.split("/")[-1]) is not None}
+
+
 def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -> Union[List[os.PathLike], str]:
     weight_names = [
         WEIGHTS_NAME,
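A quick sketch of the new module-level helper's behavior: the pattern is matched against the basename only, so component folder prefixes such as unet/ do not interfere (the regex below is a trivial stand-in, not one of the patterns from this module):

import re

def filter_with_regex(filenames, pattern_re):
    # Same shape as the helper added above: match on the basename only.
    return {f for f in filenames if pattern_re.match(f.split("/")[-1]) is not None}

fp16_safetensors_re = re.compile(r"diffusion_pytorch_model\.fp16\.safetensors$")
files = {
    "unet/diffusion_pytorch_model.fp16.safetensors",
    "unet/diffusion_pytorch_model.bin",
}
assert filter_with_regex(files, fp16_safetensors_re) == {"unet/diffusion_pytorch_model.fp16.safetensors"}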
@@ -207,9 +244,6 @@ def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -
         # interested in the extension name
         return {f for f in filenames if not any(f.endswith(pat.lstrip("*.")) for pat in ignore_patterns)}

-    def filter_with_regex(filenames, pattern_re):
-        return {f for f in filenames if pattern_re.match(f.split("/")[-1]) is not None}
-
     # Group files by component
     components = {}
     for filename in filenames:
@@ -997,7 +1031,7 @@ def _get_ignore_patterns(
         use_safetensors
         and not allow_pickle
         and not is_safetensors_compatible(
-            model_filenames, passed_components=passed_components, folder_names=model_folder_names
+            model_filenames, passed_components=passed_components, folder_names=model_folder_names, variant=variant
         )
     ):
         raise EnvironmentError(
@@ -1008,7 +1042,7 @@ def _get_ignore_patterns(
         ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]

     elif use_safetensors and is_safetensors_compatible(
-        model_filenames, passed_components=passed_components, folder_names=model_folder_names
+        model_filenames, passed_components=passed_components, folder_names=model_folder_names, variant=variant
     ):
         ignore_patterns = ["*.bin", "*.msgpack"]

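These two call sites are where the check surfaces to users: when a variant is requested together with use_safetensors and no matching safetensors files exist, the download now fails loudly instead of silently mixing files. A rough end-to-end sketch, mirroring the download test updated below (repo id and error text are taken from that test):

import tempfile

from diffusers import StableDiffusionPipeline

with tempfile.TemporaryDirectory() as tmpdirname:
    try:
        StableDiffusionPipeline.download(
            "hf-internal-testing/stable-diffusion-all-variants",
            cache_dir=tmpdirname,
            variant="no_ema",  # per the test below, this variant ships only .bin weights
            use_safetensors=True,
        )
    except OSError as err:
        assert "Could not find the necessary `safetensors` weights" in str(err)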
@@ -87,21 +87,24 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
             "unet/diffusion_pytorch_model.fp16.bin",
             "unet/diffusion_pytorch_model.fp16.safetensors",
         ]
-        self.assertTrue(is_safetensors_compatible(filenames))
+        self.assertFalse(is_safetensors_compatible(filenames))
+        self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))

     def test_diffusers_model_is_compatible_variant(self):
         filenames = [
             "unet/diffusion_pytorch_model.fp16.bin",
             "unet/diffusion_pytorch_model.fp16.safetensors",
         ]
-        self.assertTrue(is_safetensors_compatible(filenames))
+        self.assertFalse(is_safetensors_compatible(filenames))
+        self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))

     def test_diffusers_model_is_compatible_variant_mixed(self):
         filenames = [
             "unet/diffusion_pytorch_model.bin",
             "unet/diffusion_pytorch_model.fp16.safetensors",
         ]
-        self.assertTrue(is_safetensors_compatible(filenames))
+        self.assertFalse(is_safetensors_compatible(filenames))
+        self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))

     def test_diffusers_model_is_not_compatible_variant(self):
         filenames = [
@@ -121,7 +124,8 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
             "text_encoder/pytorch_model.fp16.bin",
             "text_encoder/model.fp16.safetensors",
         ]
-        self.assertTrue(is_safetensors_compatible(filenames))
+        self.assertFalse(is_safetensors_compatible(filenames))
+        self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))

     def test_transformer_model_is_not_compatible_variant(self):
         filenames = [
@@ -145,7 +149,8 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
             "unet/diffusion_pytorch_model.fp16.bin",
             "unet/diffusion_pytorch_model.fp16.safetensors",
         ]
-        self.assertTrue(is_safetensors_compatible(filenames, folder_names={"vae", "unet"}))
+        self.assertFalse(is_safetensors_compatible(filenames, folder_names={"vae", "unet"}))
+        self.assertTrue(is_safetensors_compatible(filenames, folder_names={"vae", "unet"}, variant="fp16"))

     def test_transformer_model_is_not_compatible_variant_extra_folder(self):
         filenames = [
@@ -173,7 +178,8 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
             "text_encoder/model.fp16-00001-of-00002.safetensors",
             "text_encoder/model.fp16-00001-of-00002.safetensors",
         ]
-        self.assertTrue(is_safetensors_compatible(filenames))
+        self.assertFalse(is_safetensors_compatible(filenames))
+        self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))

     def test_diffusers_is_compatible_sharded(self):
         filenames = [
@@ -189,13 +195,15 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
             "unet/diffusion_pytorch_model.fp16-00001-of-00002.safetensors",
             "unet/diffusion_pytorch_model.fp16-00001-of-00002.safetensors",
         ]
-        self.assertTrue(is_safetensors_compatible(filenames))
+        self.assertFalse(is_safetensors_compatible(filenames))
+        self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))

     def test_diffusers_is_compatible_only_variants(self):
         filenames = [
             "unet/diffusion_pytorch_model.fp16.safetensors",
         ]
-        self.assertTrue(is_safetensors_compatible(filenames))
+        self.assertFalse(is_safetensors_compatible(filenames))
+        self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))

     def test_diffusers_is_compatible_no_components(self):
         filenames = [

@@ -538,26 +538,38 @@ class DownloadTests(unittest.TestCase):
             variant = "no_ema"

             with tempfile.TemporaryDirectory() as tmpdirname:
-                tmpdirname = StableDiffusionPipeline.download(
-                    "hf-internal-testing/stable-diffusion-all-variants",
-                    cache_dir=tmpdirname,
-                    variant=variant,
-                    use_safetensors=use_safetensors,
-                )
-                all_root_files = [t[-1] for t in os.walk(tmpdirname)]
-                files = [item for sublist in all_root_files for item in sublist]
+                if use_safetensors:
+                    with self.assertRaises(OSError) as error_context:
+                        tmpdirname = StableDiffusionPipeline.download(
+                            "hf-internal-testing/stable-diffusion-all-variants",
+                            cache_dir=tmpdirname,
+                            variant=variant,
+                            use_safetensors=use_safetensors,
+                        )
+                    assert "Could not find the necessary `safetensors` weights" in str(error_context.exception)
+                else:
+                    tmpdirname = StableDiffusionPipeline.download(
+                        "hf-internal-testing/stable-diffusion-all-variants",
+                        cache_dir=tmpdirname,
+                        variant=variant,
+                        use_safetensors=use_safetensors,
+                    )
+                    all_root_files = [t[-1] for t in os.walk(tmpdirname)]
+                    files = [item for sublist in all_root_files for item in sublist]

-                unet_files = os.listdir(os.path.join(tmpdirname, "unet"))
+                    unet_files = os.listdir(os.path.join(tmpdirname, "unet"))

-                # Some of the downloaded files should be a non-variant file, check:
-                # https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
-                assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
-                # only unet has "no_ema" variant
-                assert f"diffusion_pytorch_model.{variant}{this_format}" in unet_files
-                assert len([f for f in files if f.endswith(f"{variant}{this_format}")]) == 1
-                # vae, safety_checker and text_encoder should have no variant
-                assert sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3
-                assert not any(f.endswith(other_format) for f in files)
+                    # Some of the downloaded files should be a non-variant file, check:
+                    # https://huggingface.co/hf-internal-testing/stable-diffusion-all-variants/tree/main/unet
+                    assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
+                    # only unet has "no_ema" variant
+                    assert f"diffusion_pytorch_model.{variant}{this_format}" in unet_files
+                    assert len([f for f in files if f.endswith(f"{variant}{this_format}")]) == 1
+                    # vae, safety_checker and text_encoder should have no variant
+                    assert (
+                        sum(f.endswith(this_format) and not f.endswith(f"{variant}{this_format}") for f in files) == 3
+                    )
+                    assert not any(f.endswith(other_format) for f in files)

     def test_download_variants_with_sharded_checkpoints(self):
         # Here we test for downloading of "variant" files belonging to the `unet` and
@@ -588,20 +600,17 @@ class DownloadTests(unittest.TestCase):
         logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
         deprecated_warning_msg = "Warning: The repository contains sharded checkpoints for variant"

-        for is_local in [True, False]:
-            with CaptureLogger(logger) as cap_logger:
-                with tempfile.TemporaryDirectory() as tmpdirname:
-                    local_repo_id = repo_id
-                    if is_local:
-                        local_repo_id = snapshot_download(repo_id, cache_dir=tmpdirname)
+        with CaptureLogger(logger) as cap_logger:
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                local_repo_id = snapshot_download(repo_id, cache_dir=tmpdirname)

-                    _ = DiffusionPipeline.from_pretrained(
-                        local_repo_id,
-                        safety_checker=None,
-                        variant="fp16",
-                        use_safetensors=True,
-                    )
-            assert deprecated_warning_msg in str(cap_logger), "Deprecation warning not found in logs"
+                _ = DiffusionPipeline.from_pretrained(
+                    local_repo_id,
+                    safety_checker=None,
+                    variant="fp16",
+                    use_safetensors=True,
+                )
+        assert deprecated_warning_msg in str(cap_logger), "Deprecation warning not found in logs"

     def test_download_safetensors_only_variant_exists_for_model(self):
         variant = None
@@ -616,7 +625,7 @@ class DownloadTests(unittest.TestCase):
                     variant=variant,
                     use_safetensors=use_safetensors,
                 )
-            assert "Error no file name" in str(error_context.exception)
+            assert "Could not find the necessary `safetensors` weights" in str(error_context.exception)

         # text encoder has fp16 variants so we can load it
         with tempfile.TemporaryDirectory() as tmpdirname:
@@ -675,7 +684,7 @@ class DownloadTests(unittest.TestCase):
                     use_safetensors=use_safetensors,
                 )

-            assert "Error no file name" in str(error_context.exception)
+            assert "Could not find the necessary `safetensors` weights" in str(error_context.exception)

     def test_download_bin_variant_does_not_exist_for_model(self):
         variant = "no_ema"