From 856dad57bb7a9ee13af4a08492e524b0a145a2c5 Mon Sep 17 00:00:00 2001 From: Will Berman Date: Tue, 28 Feb 2023 20:37:42 -0800 Subject: [PATCH] is_safetensors_compatible refactor (#2499) * is_safetensors_compatible refactor * files list comma --- src/diffusers/pipelines/pipeline_utils.py | 54 ++++++--- tests/pipelines/test_pipeline_utils.py | 134 ++++++++++++++++++++++ 2 files changed, 175 insertions(+), 13 deletions(-) create mode 100644 tests/pipelines/test_pipeline_utils.py diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index c6e72cf3ef..f68c497103 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -129,21 +129,49 @@ class AudioPipelineOutput(BaseOutput): def is_safetensors_compatible(filenames, variant=None) -> bool: - pt_filenames = set(filename for filename in filenames if filename.endswith(".bin")) - is_safetensors_compatible = any(file.endswith(".safetensors") for file in filenames) + """ + Checking for safetensors compatibility: + - By default, all models are saved with the default pytorch serialization, so we use the list of default pytorch + files to know which safetensors files are needed. + - The model is safetensors compatible only if there is a matching safetensors file for every default pytorch file. - for pt_filename in pt_filenames: - _variant = f".{variant}" if (variant is not None and variant in pt_filename) else "" - prefix, raw = os.path.split(pt_filename) - if raw == f"pytorch_model{_variant}.bin": - # transformers specific - sf_filename = os.path.join(prefix, f"model{_variant}.safetensors") + Converting default pytorch serialized filenames to safetensors serialized filenames: + - For models from the diffusers library, just replace the ".bin" extension with ".safetensors" + - For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin" + extension is replaced with ".safetensors" + """ + pt_filenames = [] + + sf_filenames = set() + + for filename in filenames: + _, extension = os.path.splitext(filename) + + if extension == ".bin": + pt_filenames.append(filename) + elif extension == ".safetensors": + sf_filenames.add(filename) + + for filename in pt_filenames: + # filename = 'foo/bar/baz.bam' -> path = 'foo/bar', filename = 'baz', extention = '.bam' + path, filename = os.path.split(filename) + filename, extension = os.path.splitext(filename) + + if filename == "pytorch_model": + filename = "model" + elif filename == f"pytorch_model.{variant}": + filename = f"model.{variant}" else: - sf_filename = pt_filename[: -len(".bin")] + ".safetensors" - if is_safetensors_compatible and sf_filename not in filenames: - logger.warning(f"{sf_filename} not found") - is_safetensors_compatible = False - return is_safetensors_compatible + filename = filename + + expected_sf_filename = os.path.join(path, filename) + expected_sf_filename = f"{expected_sf_filename}.safetensors" + + if expected_sf_filename not in sf_filenames: + logger.warning(f"{expected_sf_filename} not found") + return False + + return True def variant_compatible_siblings(info, variant=None) -> Union[List[os.PathLike], str]: diff --git a/tests/pipelines/test_pipeline_utils.py b/tests/pipelines/test_pipeline_utils.py new file mode 100644 index 0000000000..51d987d8bb --- /dev/null +++ b/tests/pipelines/test_pipeline_utils.py @@ -0,0 +1,134 @@ +import unittest + +from diffusers.pipelines.pipeline_utils import is_safetensors_compatible + + +class IsSafetensorsCompatibleTests(unittest.TestCase): + def test_all_is_compatible(self): + filenames = [ + "safety_checker/pytorch_model.bin", + "safety_checker/model.safetensors", + "vae/diffusion_pytorch_model.bin", + "vae/diffusion_pytorch_model.safetensors", + "text_encoder/pytorch_model.bin", + "text_encoder/model.safetensors", + "unet/diffusion_pytorch_model.bin", + "unet/diffusion_pytorch_model.safetensors", + ] + self.assertTrue(is_safetensors_compatible(filenames)) + + def test_diffusers_model_is_compatible(self): + filenames = [ + "unet/diffusion_pytorch_model.bin", + "unet/diffusion_pytorch_model.safetensors", + ] + self.assertTrue(is_safetensors_compatible(filenames)) + + def test_diffusers_model_is_not_compatible(self): + filenames = [ + "safety_checker/pytorch_model.bin", + "safety_checker/model.safetensors", + "vae/diffusion_pytorch_model.bin", + "vae/diffusion_pytorch_model.safetensors", + "text_encoder/pytorch_model.bin", + "text_encoder/model.safetensors", + "unet/diffusion_pytorch_model.bin", + # Removed: 'unet/diffusion_pytorch_model.safetensors', + ] + self.assertFalse(is_safetensors_compatible(filenames)) + + def test_transformer_model_is_compatible(self): + filenames = [ + "text_encoder/pytorch_model.bin", + "text_encoder/model.safetensors", + ] + self.assertTrue(is_safetensors_compatible(filenames)) + + def test_transformer_model_is_not_compatible(self): + filenames = [ + "safety_checker/pytorch_model.bin", + "safety_checker/model.safetensors", + "vae/diffusion_pytorch_model.bin", + "vae/diffusion_pytorch_model.safetensors", + "text_encoder/pytorch_model.bin", + # Removed: 'text_encoder/model.safetensors', + "unet/diffusion_pytorch_model.bin", + "unet/diffusion_pytorch_model.safetensors", + ] + self.assertFalse(is_safetensors_compatible(filenames)) + + def test_all_is_compatible_variant(self): + filenames = [ + "safety_checker/pytorch_model.fp16.bin", + "safety_checker/model.fp16.safetensors", + "vae/diffusion_pytorch_model.fp16.bin", + "vae/diffusion_pytorch_model.fp16.safetensors", + "text_encoder/pytorch_model.fp16.bin", + "text_encoder/model.fp16.safetensors", + "unet/diffusion_pytorch_model.fp16.bin", + "unet/diffusion_pytorch_model.fp16.safetensors", + ] + variant = "fp16" + self.assertTrue(is_safetensors_compatible(filenames, variant=variant)) + + def test_diffusers_model_is_compatible_variant(self): + filenames = [ + "unet/diffusion_pytorch_model.fp16.bin", + "unet/diffusion_pytorch_model.fp16.safetensors", + ] + variant = "fp16" + self.assertTrue(is_safetensors_compatible(filenames, variant=variant)) + + def test_diffusers_model_is_compatible_variant_partial(self): + # pass variant but use the non-variant filenames + filenames = [ + "unet/diffusion_pytorch_model.bin", + "unet/diffusion_pytorch_model.safetensors", + ] + variant = "fp16" + self.assertTrue(is_safetensors_compatible(filenames, variant=variant)) + + def test_diffusers_model_is_not_compatible_variant(self): + filenames = [ + "safety_checker/pytorch_model.fp16.bin", + "safety_checker/model.fp16.safetensors", + "vae/diffusion_pytorch_model.fp16.bin", + "vae/diffusion_pytorch_model.fp16.safetensors", + "text_encoder/pytorch_model.fp16.bin", + "text_encoder/model.fp16.safetensors", + "unet/diffusion_pytorch_model.fp16.bin", + # Removed: 'unet/diffusion_pytorch_model.fp16.safetensors', + ] + variant = "fp16" + self.assertFalse(is_safetensors_compatible(filenames, variant=variant)) + + def test_transformer_model_is_compatible_variant(self): + filenames = [ + "text_encoder/pytorch_model.fp16.bin", + "text_encoder/model.fp16.safetensors", + ] + variant = "fp16" + self.assertTrue(is_safetensors_compatible(filenames, variant=variant)) + + def test_transformer_model_is_compatible_variant_partial(self): + # pass variant but use the non-variant filenames + filenames = [ + "text_encoder/pytorch_model.bin", + "text_encoder/model.safetensors", + ] + variant = "fp16" + self.assertTrue(is_safetensors_compatible(filenames, variant=variant)) + + def test_transformer_model_is_not_compatible_variant(self): + filenames = [ + "safety_checker/pytorch_model.fp16.bin", + "safety_checker/model.fp16.safetensors", + "vae/diffusion_pytorch_model.fp16.bin", + "vae/diffusion_pytorch_model.fp16.safetensors", + "text_encoder/pytorch_model.fp16.bin", + # 'text_encoder/model.fp16.safetensors', + "unet/diffusion_pytorch_model.fp16.bin", + "unet/diffusion_pytorch_model.fp16.safetensors", + ] + variant = "fp16" + self.assertFalse(is_safetensors_compatible(filenames, variant=variant))