# Mirror of https://github.com/huggingface/diffusers.git
# Synced 2026-01-27 17:22:53 +03:00; 135 lines, 5.4 KiB, Python
import unittest
|
|
|
|
from diffusers.pipelines.pipeline_utils import is_safetensors_compatible
|
|
|
|
|
|
class IsSafetensorsCompatibleTests(unittest.TestCase):
    """Exercise `is_safetensors_compatible` with lists of checkpoint file names.

    Each test builds a list of repository-relative file names and asserts the
    boolean outcome of `is_safetensors_compatible`. The fixtures use two naming
    schemes: `diffusion_pytorch_model.{bin,safetensors}` (vae, unet) and
    `pytorch_model.bin` / `model.safetensors` (text_encoder, safety_checker).
    The second half of the class repeats the scenarios with `variant="fp16"`.
    """

    def test_all_is_compatible(self):
        """Every component lists both a `.bin` and a `.safetensors` file."""
        filenames = [
            "safety_checker/pytorch_model.bin",
            "safety_checker/model.safetensors",
            "vae/diffusion_pytorch_model.bin",
            "vae/diffusion_pytorch_model.safetensors",
            "text_encoder/pytorch_model.bin",
            "text_encoder/model.safetensors",
            "unet/diffusion_pytorch_model.bin",
            "unet/diffusion_pytorch_model.safetensors",
        ]
        self.assertTrue(is_safetensors_compatible(filenames))

    def test_diffusers_model_is_compatible(self):
        """A lone unet with both file formats is expected to be compatible."""
        filenames = [
            "unet/diffusion_pytorch_model.bin",
            "unet/diffusion_pytorch_model.safetensors",
        ]
        self.assertTrue(is_safetensors_compatible(filenames))

    def test_diffusers_model_is_not_compatible(self):
        """The unet lists only a `.bin` file, so the result must be falsy."""
        filenames = [
            "safety_checker/pytorch_model.bin",
            "safety_checker/model.safetensors",
            "vae/diffusion_pytorch_model.bin",
            "vae/diffusion_pytorch_model.safetensors",
            "text_encoder/pytorch_model.bin",
            "text_encoder/model.safetensors",
            "unet/diffusion_pytorch_model.bin",
            # Deliberately absent: "unet/diffusion_pytorch_model.safetensors"
        ]
        self.assertFalse(is_safetensors_compatible(filenames))

    def test_transformer_model_is_compatible(self):
        """A lone text_encoder with both file formats is expected to be compatible."""
        filenames = [
            "text_encoder/pytorch_model.bin",
            "text_encoder/model.safetensors",
        ]
        self.assertTrue(is_safetensors_compatible(filenames))

    def test_transformer_model_is_not_compatible(self):
        """The text_encoder lists only a `.bin` file, so the result must be falsy."""
        filenames = [
            "safety_checker/pytorch_model.bin",
            "safety_checker/model.safetensors",
            "vae/diffusion_pytorch_model.bin",
            "vae/diffusion_pytorch_model.safetensors",
            "text_encoder/pytorch_model.bin",
            # Deliberately absent: "text_encoder/model.safetensors"
            "unet/diffusion_pytorch_model.bin",
            "unet/diffusion_pytorch_model.safetensors",
        ]
        self.assertFalse(is_safetensors_compatible(filenames))

    def test_all_is_compatible_variant(self):
        """Every component lists both formats under the `fp16` variant names."""
        filenames = [
            "safety_checker/pytorch_model.fp16.bin",
            "safety_checker/model.fp16.safetensors",
            "vae/diffusion_pytorch_model.fp16.bin",
            "vae/diffusion_pytorch_model.fp16.safetensors",
            "text_encoder/pytorch_model.fp16.bin",
            "text_encoder/model.fp16.safetensors",
            "unet/diffusion_pytorch_model.fp16.bin",
            "unet/diffusion_pytorch_model.fp16.safetensors",
        ]
        self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))

    def test_diffusers_model_is_compatible_variant(self):
        """A lone unet with both `fp16` file formats is expected to be compatible."""
        filenames = [
            "unet/diffusion_pytorch_model.fp16.bin",
            "unet/diffusion_pytorch_model.fp16.safetensors",
        ]
        self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))

    def test_diffusers_model_is_compatible_variant_partial(self):
        """A variant is requested but the unet files carry no variant suffix."""
        # The variant is passed while the file names are the plain, non-variant ones.
        filenames = [
            "unet/diffusion_pytorch_model.bin",
            "unet/diffusion_pytorch_model.safetensors",
        ]
        self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))

    def test_diffusers_model_is_not_compatible_variant(self):
        """The unet lists only an `fp16` `.bin` file, so the result must be falsy."""
        filenames = [
            "safety_checker/pytorch_model.fp16.bin",
            "safety_checker/model.fp16.safetensors",
            "vae/diffusion_pytorch_model.fp16.bin",
            "vae/diffusion_pytorch_model.fp16.safetensors",
            "text_encoder/pytorch_model.fp16.bin",
            "text_encoder/model.fp16.safetensors",
            "unet/diffusion_pytorch_model.fp16.bin",
            # Deliberately absent: "unet/diffusion_pytorch_model.fp16.safetensors"
        ]
        self.assertFalse(is_safetensors_compatible(filenames, variant="fp16"))

    def test_transformer_model_is_compatible_variant(self):
        """A lone text_encoder with both `fp16` file formats is expected to be compatible."""
        filenames = [
            "text_encoder/pytorch_model.fp16.bin",
            "text_encoder/model.fp16.safetensors",
        ]
        self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))

    def test_transformer_model_is_compatible_variant_partial(self):
        """A variant is requested but the text_encoder files carry no variant suffix."""
        # The variant is passed while the file names are the plain, non-variant ones.
        filenames = [
            "text_encoder/pytorch_model.bin",
            "text_encoder/model.safetensors",
        ]
        self.assertTrue(is_safetensors_compatible(filenames, variant="fp16"))

    def test_transformer_model_is_not_compatible_variant(self):
        """The text_encoder lists only an `fp16` `.bin` file, so the result must be falsy."""
        filenames = [
            "safety_checker/pytorch_model.fp16.bin",
            "safety_checker/model.fp16.safetensors",
            "vae/diffusion_pytorch_model.fp16.bin",
            "vae/diffusion_pytorch_model.fp16.safetensors",
            "text_encoder/pytorch_model.fp16.bin",
            # Deliberately absent: "text_encoder/model.fp16.safetensors"
            "unet/diffusion_pytorch_model.fp16.bin",
            "unet/diffusion_pytorch_model.fp16.safetensors",
        ]
        self.assertFalse(is_safetensors_compatible(filenames, variant="fp16"))