mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
update
This commit is contained in:
@@ -462,13 +462,15 @@ class ModelTesterMixin:
|
||||
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
fp32_modules = model._keep_in_fp32_modules
|
||||
fp32_modules = model._keep_in_fp32_modules or []
|
||||
|
||||
model.to(dtype).save_pretrained(tmp_path)
|
||||
model_loaded = self.model_class.from_pretrained(tmp_path, torch_dtype=dtype).to(torch_device)
|
||||
|
||||
for name, param in model_loaded.named_parameters():
|
||||
if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules):
|
||||
if fp32_modules and any(
|
||||
module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules
|
||||
):
|
||||
assert param.data.dtype == torch.float32
|
||||
else:
|
||||
assert param.data.dtype == dtype
|
||||
|
||||
Reference in New Issue
Block a user