1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[From pretrained] Don't download .safetensors files if safetensors is… (#2057)

* [From pretrained] Don't download .safetensors files if safetensors is not available

* tests

* tests

* up
This commit is contained in:
Patrick von Platen
2023-01-21 16:51:33 +02:00
committed by GitHub
parent aa265f74bd
commit 59b7339a84
2 changed files with 58 additions and 0 deletions

View File

@@ -514,7 +514,11 @@ class DiffusionPipeline(ConfigMixin):
if is_safetensors_compatible(info):
ignore_patterns.append("*.bin")
else:
# as a safety mechanism we also don't download safetensors if
# not all safetensors files are there
ignore_patterns.append("*.safetensors")
else:
ignore_patterns.append("*.safetensors")
# download all allow_patterns
cached_folder = snapshot_download(

View File

@@ -299,6 +299,16 @@ class CustomPipelineTests(unittest.TestCase):
class PipelineFastTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
import diffusers
diffusers.utils.import_utils._safetensors_available = True
def dummy_image(self):
batch_size = 1
num_channels = 3
@@ -567,6 +577,50 @@ class PipelineFastTests(unittest.TestCase):
assert pipeline.scheduler is not None
assert pipeline.feature_extractor is not None
def test_no_pytorch_download_when_doing_safetensors(self):
# by default we don't download
with tempfile.TemporaryDirectory() as tmpdirname:
_ = StableDiffusionPipeline.from_pretrained(
"hf-internal-testing/diffusers-stable-diffusion-tiny-all", cache_dir=tmpdirname
)
path = os.path.join(
tmpdirname,
"models--hf-internal-testing--diffusers-stable-diffusion-tiny-all",
"snapshots",
"07838d72e12f9bcec1375b0482b80c1d399be843",
"unet",
)
# safetensors exists
assert os.path.exists(os.path.join(path, "diffusion_pytorch_model.safetensors"))
# pytorch does not
assert not os.path.exists(os.path.join(path, "diffusion_pytorch_model.bin"))
def test_no_safetensors_download_when_doing_pytorch(self):
# mock diffusers safetensors not available
import diffusers
diffusers.utils.import_utils._safetensors_available = False
with tempfile.TemporaryDirectory() as tmpdirname:
_ = StableDiffusionPipeline.from_pretrained(
"hf-internal-testing/diffusers-stable-diffusion-tiny-all", cache_dir=tmpdirname
)
path = os.path.join(
tmpdirname,
"models--hf-internal-testing--diffusers-stable-diffusion-tiny-all",
"snapshots",
"07838d72e12f9bcec1375b0482b80c1d399be843",
"unet",
)
# safetensors does not exists
assert not os.path.exists(os.path.join(path, "diffusion_pytorch_model.safetensors"))
# pytorch does
assert os.path.exists(os.path.join(path, "diffusion_pytorch_model.bin"))
diffusers.utils.import_utils._safetensors_available = True
def test_optional_components(self):
unet = self.dummy_cond_unet()
pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")