mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[From pretrained] Don't download .safetensors files if safetensors is… (#2057)
* [From pretrained] Don't download .safetensors files if safetensors is not available * tests * tests * up
This commit is contained in:
committed by
GitHub
parent
aa265f74bd
commit
59b7339a84
@@ -514,7 +514,11 @@ class DiffusionPipeline(ConfigMixin):
|
||||
if is_safetensors_compatible(info):
|
||||
ignore_patterns.append("*.bin")
|
||||
else:
|
||||
# as a safety mechanism we also don't download safetensors if
|
||||
# not all safetensors files are there
|
||||
ignore_patterns.append("*.safetensors")
|
||||
else:
|
||||
ignore_patterns.append("*.safetensors")
|
||||
|
||||
# download all allow_patterns
|
||||
cached_folder = snapshot_download(
|
||||
|
||||
@@ -299,6 +299,16 @@ class CustomPipelineTests(unittest.TestCase):
|
||||
|
||||
|
||||
class PipelineFastTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
import diffusers
|
||||
|
||||
diffusers.utils.import_utils._safetensors_available = True
|
||||
|
||||
def dummy_image(self):
|
||||
batch_size = 1
|
||||
num_channels = 3
|
||||
@@ -567,6 +577,50 @@ class PipelineFastTests(unittest.TestCase):
|
||||
assert pipeline.scheduler is not None
|
||||
assert pipeline.feature_extractor is not None
|
||||
|
||||
def test_no_pytorch_download_when_doing_safetensors(self):
|
||||
# by default we don't download
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
_ = StableDiffusionPipeline.from_pretrained(
|
||||
"hf-internal-testing/diffusers-stable-diffusion-tiny-all", cache_dir=tmpdirname
|
||||
)
|
||||
|
||||
path = os.path.join(
|
||||
tmpdirname,
|
||||
"models--hf-internal-testing--diffusers-stable-diffusion-tiny-all",
|
||||
"snapshots",
|
||||
"07838d72e12f9bcec1375b0482b80c1d399be843",
|
||||
"unet",
|
||||
)
|
||||
# safetensors exists
|
||||
assert os.path.exists(os.path.join(path, "diffusion_pytorch_model.safetensors"))
|
||||
# pytorch does not
|
||||
assert not os.path.exists(os.path.join(path, "diffusion_pytorch_model.bin"))
|
||||
|
||||
def test_no_safetensors_download_when_doing_pytorch(self):
|
||||
# mock diffusers safetensors not available
|
||||
import diffusers
|
||||
|
||||
diffusers.utils.import_utils._safetensors_available = False
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
_ = StableDiffusionPipeline.from_pretrained(
|
||||
"hf-internal-testing/diffusers-stable-diffusion-tiny-all", cache_dir=tmpdirname
|
||||
)
|
||||
|
||||
path = os.path.join(
|
||||
tmpdirname,
|
||||
"models--hf-internal-testing--diffusers-stable-diffusion-tiny-all",
|
||||
"snapshots",
|
||||
"07838d72e12f9bcec1375b0482b80c1d399be843",
|
||||
"unet",
|
||||
)
|
||||
# safetensors does not exists
|
||||
assert not os.path.exists(os.path.join(path, "diffusion_pytorch_model.safetensors"))
|
||||
# pytorch does
|
||||
assert os.path.exists(os.path.join(path, "diffusion_pytorch_model.bin"))
|
||||
|
||||
diffusers.utils.import_utils._safetensors_available = True
|
||||
|
||||
def test_optional_components(self):
|
||||
unet = self.dummy_cond_unet()
|
||||
pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
|
||||
|
||||
Reference in New Issue
Block a user