1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
DN6
2025-09-19 13:27:08 +05:30
parent 710e18b951
commit 1f6defd7d6
4 changed files with 18 additions and 34 deletions

View File

@@ -119,6 +119,24 @@ class SingleFileModelTesterMixin:
f"max difference {torch.max(torch.abs(param - param_single_file)).item()}"
)
def test_checkpoint_altered_keys_loading(self):
# Test loading with checkpoints that have altered keys
if not hasattr(self, "alternate_keys_ckpt_paths") or not self.alternate_keys_ckpt_paths:
return
for ckpt_path in self.alternate_keys_ckpt_paths:
backend_empty_cache(torch_device)
single_file_kwargs = {}
if hasattr(self, "torch_dtype") and self.torch_dtype:
single_file_kwargs["torch_dtype"] = self.torch_dtype
model = self.model_class.from_single_file(ckpt_path, **single_file_kwargs)
del model
gc.collect()
backend_empty_cache(torch_device)
class SDSingleFileTesterMixin:
single_file_kwargs = {}

View File

@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
from diffusers import (
@@ -21,9 +20,7 @@ from diffusers import (
)
from ..testing_utils import (
backend_empty_cache,
enable_full_determinism,
torch_device,
)
from .single_file_testing_utils import SingleFileModelTesterMixin
@@ -40,12 +37,3 @@ class Lumina2Transformer2DModelSingleFileTests(SingleFileModelTesterMixin, unitt
repo_id = "Alpha-VLLM/Lumina-Image-2.0"
subfolder = "transformer"
def test_checkpoint_loading(self):
for ckpt_path in self.alternate_keys_ckpt_paths:
backend_empty_cache(torch_device)
model = self.model_class.from_single_file(ckpt_path)
del model
gc.collect()
backend_empty_cache(torch_device)

View File

@@ -37,18 +37,8 @@ class FluxTransformer2DModelSingleFileTests(SingleFileModelTesterMixin, unittest
alternate_keys_ckpt_paths = ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"]
repo_id = "black-forest-labs/FLUX.1-dev"
subfolder = "transformer"
def test_checkpoint_loading(self):
for ckpt_path in self.alternate_keys_ckpt_paths:
backend_empty_cache(torch_device)
model = self.model_class.from_single_file(ckpt_path)
del model
gc.collect()
backend_empty_cache(torch_device)
def test_device_map_cuda(self):
backend_empty_cache(torch_device)
model = self.model_class.from_single_file(self.ckpt_path, device_map="cuda")

View File

@@ -1,4 +1,3 @@
import gc
import unittest
from diffusers import (
@@ -6,9 +5,7 @@ from diffusers import (
)
from ..testing_utils import (
backend_empty_cache,
enable_full_determinism,
torch_device,
)
from .single_file_testing_utils import SingleFileModelTesterMixin
@@ -27,12 +24,3 @@ class SanaTransformer2DModelSingleFileTests(SingleFileModelTesterMixin, unittest
repo_id = "Efficient-Large-Model/Sana_1600M_1024px_diffusers"
subfolder = "transformer"
def test_checkpoint_loading(self):
for ckpt_path in self.alternate_keys_ckpt_paths:
backend_empty_cache(torch_device)
model = self.model_class.from_single_file(ckpt_path)
del model
gc.collect()
backend_empty_cache(torch_device)