From 59b7339a849e900795cc33b92edcef3e34afa72c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 21 Jan 2023 16:51:33 +0200 Subject: [PATCH] =?UTF-8?q?[From=20pretrained]=20Don't=20download=20.safet?= =?UTF-8?q?ensors=20files=20if=20safetensors=20is=E2=80=A6=20(#2057)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [From pretrained] Don't download .safetensors files if safetensors is not available * tests * tests * up --- src/diffusers/pipelines/pipeline_utils.py | 4 ++ tests/test_pipelines.py | 54 +++++++++++++++++++++++ 2 files changed, 58 insertions(+) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index ea28ac875f..14f0454f6d 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -514,7 +514,11 @@ class DiffusionPipeline(ConfigMixin): if is_safetensors_compatible(info): ignore_patterns.append("*.bin") else: + # as a safety mechanism we also don't download safetensors if + # not all safetensors files are there ignore_patterns.append("*.safetensors") + else: + ignore_patterns.append("*.safetensors") # download all allow_patterns cached_folder = snapshot_download( diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index cf8ef8d6da..9e44ee6e7c 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -299,6 +299,16 @@ class CustomPipelineTests(unittest.TestCase): class PipelineFastTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + import diffusers + + diffusers.utils.import_utils._safetensors_available = True + def dummy_image(self): batch_size = 1 num_channels = 3 @@ -567,6 +577,50 @@ class PipelineFastTests(unittest.TestCase): assert pipeline.scheduler is not None assert pipeline.feature_extractor is not None + def test_no_pytorch_download_when_doing_safetensors(self): + # by default we don't download + with tempfile.TemporaryDirectory() as tmpdirname: + _ = StableDiffusionPipeline.from_pretrained( + "hf-internal-testing/diffusers-stable-diffusion-tiny-all", cache_dir=tmpdirname + ) + + path = os.path.join( + tmpdirname, + "models--hf-internal-testing--diffusers-stable-diffusion-tiny-all", + "snapshots", + "07838d72e12f9bcec1375b0482b80c1d399be843", + "unet", + ) + # safetensors exists + assert os.path.exists(os.path.join(path, "diffusion_pytorch_model.safetensors")) + # pytorch does not + assert not os.path.exists(os.path.join(path, "diffusion_pytorch_model.bin")) + + def test_no_safetensors_download_when_doing_pytorch(self): + # mock diffusers safetensors not available + import diffusers + + diffusers.utils.import_utils._safetensors_available = False + + with tempfile.TemporaryDirectory() as tmpdirname: + _ = StableDiffusionPipeline.from_pretrained( + "hf-internal-testing/diffusers-stable-diffusion-tiny-all", cache_dir=tmpdirname + ) + + path = os.path.join( + tmpdirname, + "models--hf-internal-testing--diffusers-stable-diffusion-tiny-all", + "snapshots", + "07838d72e12f9bcec1375b0482b80c1d399be843", + "unet", + ) + # safetensors does not exists + assert not os.path.exists(os.path.join(path, "diffusion_pytorch_model.safetensors")) + # pytorch does + assert os.path.exists(os.path.join(path, "diffusion_pytorch_model.bin")) + + diffusers.utils.import_utils._safetensors_available = True + def test_optional_components(self): unet = self.dummy_cond_unet() pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")