mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
update
This commit is contained in:
@@ -119,6 +119,24 @@ class SingleFileModelTesterMixin:
|
||||
f"max difference {torch.max(torch.abs(param - param_single_file)).item()}"
|
||||
)
|
||||
|
||||
def test_checkpoint_altered_keys_loading(self):
|
||||
# Test loading with checkpoints that have altered keys
|
||||
if not hasattr(self, "alternate_keys_ckpt_paths") or not self.alternate_keys_ckpt_paths:
|
||||
return
|
||||
|
||||
for ckpt_path in self.alternate_keys_ckpt_paths:
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
single_file_kwargs = {}
|
||||
if hasattr(self, "torch_dtype") and self.torch_dtype:
|
||||
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||
|
||||
model = self.model_class.from_single_file(ckpt_path, **single_file_kwargs)
|
||||
|
||||
del model
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
|
||||
class SDSingleFileTesterMixin:
|
||||
single_file_kwargs = {}
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
from diffusers import (
|
||||
@@ -21,9 +20,7 @@ from diffusers import (
|
||||
)
|
||||
|
||||
from ..testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
torch_device,
|
||||
)
|
||||
from .single_file_testing_utils import SingleFileModelTesterMixin
|
||||
|
||||
@@ -40,12 +37,3 @@ class Lumina2Transformer2DModelSingleFileTests(SingleFileModelTesterMixin, unitt
|
||||
|
||||
repo_id = "Alpha-VLLM/Lumina-Image-2.0"
|
||||
subfolder = "transformer"
|
||||
|
||||
def test_checkpoint_loading(self):
|
||||
for ckpt_path in self.alternate_keys_ckpt_paths:
|
||||
backend_empty_cache(torch_device)
|
||||
model = self.model_class.from_single_file(ckpt_path)
|
||||
|
||||
del model
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
@@ -37,18 +37,8 @@ class FluxTransformer2DModelSingleFileTests(SingleFileModelTesterMixin, unittest
|
||||
alternate_keys_ckpt_paths = ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"]
|
||||
|
||||
repo_id = "black-forest-labs/FLUX.1-dev"
|
||||
|
||||
subfolder = "transformer"
|
||||
|
||||
def test_checkpoint_loading(self):
|
||||
for ckpt_path in self.alternate_keys_ckpt_paths:
|
||||
backend_empty_cache(torch_device)
|
||||
model = self.model_class.from_single_file(ckpt_path)
|
||||
|
||||
del model
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_device_map_cuda(self):
|
||||
backend_empty_cache(torch_device)
|
||||
model = self.model_class.from_single_file(self.ckpt_path, device_map="cuda")
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
from diffusers import (
|
||||
@@ -6,9 +5,7 @@ from diffusers import (
|
||||
)
|
||||
|
||||
from ..testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
torch_device,
|
||||
)
|
||||
from .single_file_testing_utils import SingleFileModelTesterMixin
|
||||
|
||||
@@ -27,12 +24,3 @@ class SanaTransformer2DModelSingleFileTests(SingleFileModelTesterMixin, unittest
|
||||
|
||||
repo_id = "Efficient-Large-Model/Sana_1600M_1024px_diffusers"
|
||||
subfolder = "transformer"
|
||||
|
||||
def test_checkpoint_loading(self):
|
||||
for ckpt_path in self.alternate_keys_ckpt_paths:
|
||||
backend_empty_cache(torch_device)
|
||||
model = self.model_class.from_single_file(ckpt_path)
|
||||
|
||||
del model
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
Reference in New Issue
Block a user