mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
update
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import gc
|
||||
import tempfile
|
||||
from io import BytesIO
|
||||
|
||||
@@ -9,7 +10,9 @@ from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_nam
|
||||
from diffusers.models.attention_processor import AttnProcessor
|
||||
|
||||
from ..testing_utils import (
|
||||
backend_empty_cache,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_torch_accelerator,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
@@ -47,6 +50,76 @@ def download_diffusers_config(repo_id, tmpdir):
|
||||
return path
|
||||
|
||||
|
||||
@require_torch_accelerator
|
||||
class SingleFileModelTesterMixin:
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_single_file_model_config(self):
|
||||
pretrained_kwargs = {}
|
||||
single_file_kwargs = {}
|
||||
|
||||
if hasattr(self, "subfolder") and self.subfolder:
|
||||
pretrained_kwargs["subfolder"] = self.subfolder
|
||||
|
||||
if hasattr(self, "torch_dtype") and self.torch_dtype:
|
||||
pretrained_kwargs["torch_dtype"] = self.torch_dtype
|
||||
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||
|
||||
model = self.model_class.from_pretrained(self.repo_id, **pretrained_kwargs)
|
||||
model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
|
||||
for param_name, param_value in model_single_file.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert model.config[param_name] == param_value, (
|
||||
f"{param_name} differs between pretrained loading and single file loading"
|
||||
)
|
||||
|
||||
def test_single_file_model_parameters(self):
|
||||
pretrained_kwargs = {}
|
||||
single_file_kwargs = {}
|
||||
|
||||
if hasattr(self, "subfolder") and self.subfolder:
|
||||
pretrained_kwargs["subfolder"] = self.subfolder
|
||||
|
||||
if hasattr(self, "torch_dtype") and self.torch_dtype:
|
||||
pretrained_kwargs["torch_dtype"] = self.torch_dtype
|
||||
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||
|
||||
model = self.model_class.from_pretrained(self.repo_id, **pretrained_kwargs)
|
||||
model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)
|
||||
|
||||
state_dict = model.state_dict()
|
||||
state_dict_single_file = model_single_file.state_dict()
|
||||
|
||||
assert set(state_dict.keys()) == set(state_dict_single_file.keys()), (
|
||||
"Model parameters keys differ between pretrained and single file loading"
|
||||
)
|
||||
|
||||
for key in state_dict.keys():
|
||||
param = state_dict[key]
|
||||
param_single_file = state_dict_single_file[key]
|
||||
|
||||
assert param.shape == param_single_file.shape, (
|
||||
f"Parameter shape mismatch for {key}: "
|
||||
f"pretrained {param.shape} vs single file {param_single_file.shape}"
|
||||
)
|
||||
|
||||
assert torch.allclose(param, param_single_file, rtol=1e-5, atol=1e-5), (
|
||||
f"Parameter values differ for {key}: "
|
||||
f"max difference {torch.max(torch.abs(param - param_single_file)).item()}"
|
||||
)
|
||||
|
||||
|
||||
class SDSingleFileTesterMixin:
|
||||
single_file_kwargs = {}
|
||||
|
||||
|
||||
@@ -23,16 +23,15 @@ from diffusers import (
|
||||
from ..testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
require_torch_accelerator,
|
||||
torch_device,
|
||||
)
|
||||
from .single_file_testing_utils import SingleFileModelTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
@require_torch_accelerator
|
||||
class Lumina2Transformer2DModelSingleFileTests(unittest.TestCase):
|
||||
class Lumina2Transformer2DModelSingleFileTests(SingleFileModelTesterMixin, unittest.TestCase):
|
||||
model_class = Lumina2Transformer2DModel
|
||||
ckpt_path = "https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/diffusion_models/lumina_2_model_bf16.safetensors"
|
||||
alternate_keys_ckpt_paths = [
|
||||
@@ -40,28 +39,7 @@ class Lumina2Transformer2DModelSingleFileTests(unittest.TestCase):
|
||||
]
|
||||
|
||||
repo_id = "Alpha-VLLM/Lumina-Image-2.0"
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_single_file_components(self):
|
||||
model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
|
||||
model_single_file = self.model_class.from_single_file(self.ckpt_path)
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
|
||||
for param_name, param_value in model_single_file.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert model.config[param_name] == param_value, (
|
||||
f"{param_name} differs between single file loading and pretrained loading"
|
||||
)
|
||||
subfolder = "transformer"
|
||||
|
||||
def test_checkpoint_loading(self):
|
||||
for ckpt_path in self.alternate_keys_ckpt_paths:
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
@@ -23,38 +22,24 @@ from diffusers import (
|
||||
)
|
||||
|
||||
from ..testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
load_hf_numpy,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_torch_accelerator,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from .single_file_testing_utils import SingleFileModelTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_accelerator
|
||||
class AutoencoderDCSingleFileTests(unittest.TestCase):
|
||||
class AutoencoderDCSingleFileTests(SingleFileModelTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderDC
|
||||
ckpt_path = "https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.0/blob/main/model.safetensors"
|
||||
repo_id = "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def get_file_format(self, seed, shape):
|
||||
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
|
||||
|
||||
@@ -80,18 +65,6 @@ class AutoencoderDCSingleFileTests(unittest.TestCase):
|
||||
|
||||
assert numpy_cosine_similarity_distance(output_slice_1, output_slice_2) < 1e-4
|
||||
|
||||
def test_single_file_components(self):
|
||||
model = self.model_class.from_pretrained(self.repo_id)
|
||||
model_single_file = self.model_class.from_single_file(self.ckpt_path)
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
|
||||
for param_name, param_value in model_single_file.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert model.config[param_name] == param_value, (
|
||||
f"{param_name} differs between pretrained loading and single file loading"
|
||||
)
|
||||
|
||||
def test_single_file_in_type_variant_components(self):
|
||||
# `in` variant checkpoints require passing in a `config` parameter
|
||||
# in order to set the scaling factor correctly.
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
@@ -23,46 +22,19 @@ from diffusers import (
|
||||
)
|
||||
|
||||
from ..testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
require_torch_accelerator,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from .single_file_testing_utils import SingleFileModelTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_accelerator
|
||||
class ControlNetModelSingleFileTests(unittest.TestCase):
|
||||
class ControlNetModelSingleFileTests(SingleFileModelTesterMixin, unittest.TestCase):
|
||||
model_class = ControlNetModel
|
||||
ckpt_path = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
|
||||
repo_id = "lllyasviel/control_v11p_sd15_canny"
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_single_file_components(self):
|
||||
model = self.model_class.from_pretrained(self.repo_id)
|
||||
model_single_file = self.model_class.from_single_file(self.ckpt_path)
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
|
||||
for param_name, param_value in model_single_file.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert model.config[param_name] == param_value, (
|
||||
f"{param_name} differs between single file loading and pretrained loading"
|
||||
)
|
||||
|
||||
def test_single_file_arguments(self):
|
||||
model_default = self.model_class.from_single_file(self.ckpt_path)
|
||||
|
||||
|
||||
@@ -23,43 +23,22 @@ from diffusers import (
|
||||
from ..testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
require_torch_accelerator,
|
||||
torch_device,
|
||||
)
|
||||
from .single_file_testing_utils import SingleFileModelTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
@require_torch_accelerator
|
||||
class FluxTransformer2DModelSingleFileTests(unittest.TestCase):
|
||||
class FluxTransformer2DModelSingleFileTests(SingleFileModelTesterMixin, unittest.TestCase):
|
||||
model_class = FluxTransformer2DModel
|
||||
ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
|
||||
alternate_keys_ckpt_paths = ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"]
|
||||
|
||||
repo_id = "black-forest-labs/FLUX.1-dev"
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_single_file_components(self):
|
||||
model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
|
||||
model_single_file = self.model_class.from_single_file(self.ckpt_path)
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
|
||||
for param_name, param_value in model_single_file.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert model.config[param_name] == param_value, (
|
||||
f"{param_name} differs between single file loading and pretrained loading"
|
||||
)
|
||||
subfolder = "transformer"
|
||||
|
||||
def test_checkpoint_loading(self):
|
||||
for ckpt_path in self.alternate_keys_ckpt_paths:
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
@@ -23,22 +22,18 @@ from diffusers import (
|
||||
)
|
||||
|
||||
from ..testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
load_hf_numpy,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_torch_accelerator,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from .single_file_testing_utils import SingleFileModelTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_accelerator
|
||||
class AutoencoderKLSingleFileTests(unittest.TestCase):
|
||||
class AutoencoderKLSingleFileTests(SingleFileModelTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderKL
|
||||
ckpt_path = (
|
||||
"https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"
|
||||
@@ -47,16 +42,6 @@ class AutoencoderKLSingleFileTests(unittest.TestCase):
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def get_file_format(self, seed, shape):
|
||||
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
|
||||
|
||||
@@ -84,18 +69,6 @@ class AutoencoderKLSingleFileTests(unittest.TestCase):
|
||||
|
||||
assert numpy_cosine_similarity_distance(output_slice_1, output_slice_2) < 1e-4
|
||||
|
||||
def test_single_file_components(self):
|
||||
model = self.model_class.from_pretrained(self.repo_id)
|
||||
model_single_file = self.model_class.from_single_file(self.ckpt_path, config=self.repo_id)
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
|
||||
for param_name, param_value in model_single_file.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert model.config[param_name] == param_value, (
|
||||
f"{param_name} differs between pretrained loading and single file loading"
|
||||
)
|
||||
|
||||
def test_single_file_arguments(self):
|
||||
model_default = self.model_class.from_single_file(self.ckpt_path, config=self.repo_id)
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
from diffusers import (
|
||||
@@ -21,42 +20,18 @@ from diffusers import (
|
||||
)
|
||||
|
||||
from ..testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
require_torch_accelerator,
|
||||
torch_device,
|
||||
)
|
||||
from .single_file_testing_utils import SingleFileModelTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
@require_torch_accelerator
|
||||
class AutoencoderKLWanSingleFileTests(unittest.TestCase):
|
||||
class AutoencoderKLWanSingleFileTests(SingleFileModelTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderKLWan
|
||||
ckpt_path = (
|
||||
"https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/vae/wan_2.1_vae.safetensors"
|
||||
)
|
||||
repo_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_single_file_components(self):
|
||||
model = self.model_class.from_pretrained(self.repo_id, subfolder="vae")
|
||||
model_single_file = self.model_class.from_single_file(self.ckpt_path)
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
|
||||
for param_name, param_value in model_single_file.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert model.config[param_name] == param_value, (
|
||||
f"{param_name} differs between single file loading and pretrained loading"
|
||||
)
|
||||
subfolder = "vae"
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
@@ -23,72 +22,26 @@ from diffusers import (
|
||||
)
|
||||
|
||||
from ..testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
require_big_accelerator,
|
||||
require_torch_accelerator,
|
||||
torch_device,
|
||||
)
|
||||
from .single_file_testing_utils import SingleFileModelTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
@require_torch_accelerator
|
||||
class WanTransformer3DModelText2VideoSingleFileTest(unittest.TestCase):
|
||||
class WanTransformer3DModelText2VideoSingleFileTest(SingleFileModelTesterMixin, unittest.TestCase):
|
||||
model_class = WanTransformer3DModel
|
||||
ckpt_path = "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors"
|
||||
repo_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_single_file_components(self):
|
||||
model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
|
||||
model_single_file = self.model_class.from_single_file(self.ckpt_path)
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
|
||||
for param_name, param_value in model_single_file.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert model.config[param_name] == param_value, (
|
||||
f"{param_name} differs between single file loading and pretrained loading"
|
||||
)
|
||||
subfolder = "transformer"
|
||||
|
||||
|
||||
@require_big_accelerator
|
||||
@require_torch_accelerator
|
||||
class WanTransformer3DModelImage2VideoSingleFileTest(unittest.TestCase):
|
||||
class WanTransformer3DModelImage2VideoSingleFileTest(SingleFileModelTesterMixin, unittest.TestCase):
|
||||
model_class = WanTransformer3DModel
|
||||
ckpt_path = "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_i2v_480p_14B_fp8_e4m3fn.safetensors"
|
||||
repo_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
|
||||
torch_dtype = torch.float8_e4m3fn
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_single_file_components(self):
|
||||
model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer", torch_dtype=self.torch_dtype)
|
||||
model_single_file = self.model_class.from_single_file(self.ckpt_path, torch_dtype=self.torch_dtype)
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
|
||||
for param_name, param_value in model_single_file.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert model.config[param_name] == param_value, (
|
||||
f"{param_name} differs between single file loading and pretrained loading"
|
||||
)
|
||||
subfolder = "transformer"
|
||||
|
||||
@@ -8,16 +8,15 @@ from diffusers import (
|
||||
from ..testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
require_torch_accelerator,
|
||||
torch_device,
|
||||
)
|
||||
from .single_file_testing_utils import SingleFileModelTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
@require_torch_accelerator
|
||||
class SanaTransformer2DModelSingleFileTests(unittest.TestCase):
|
||||
class SanaTransformer2DModelSingleFileTests(SingleFileModelTesterMixin, unittest.TestCase):
|
||||
model_class = SanaTransformer2DModel
|
||||
ckpt_path = (
|
||||
"https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px/blob/main/checkpoints/Sana_1600M_1024px.pth"
|
||||
@@ -27,28 +26,7 @@ class SanaTransformer2DModelSingleFileTests(unittest.TestCase):
|
||||
]
|
||||
|
||||
repo_id = "Efficient-Large-Model/Sana_1600M_1024px_diffusers"
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_single_file_components(self):
|
||||
model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
|
||||
model_single_file = self.model_class.from_single_file(self.ckpt_path)
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
|
||||
for param_name, param_value in model_single_file.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert model.config[param_name] == param_value, (
|
||||
f"{param_name} differs between single file loading and pretrained loading"
|
||||
)
|
||||
subfolder = "transformer"
|
||||
|
||||
def test_checkpoint_loading(self):
|
||||
for ckpt_path in self.alternate_keys_ckpt_paths:
|
||||
|
||||
Reference in New Issue
Block a user