1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
This commit is contained in:
DN6
2026-01-13 15:15:13 +05:30
parent ce3097c65b
commit da801e97ba

View File

@@ -462,13 +462,15 @@ class ModelTesterMixin:
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
fp32_modules = model._keep_in_fp32_modules
fp32_modules = model._keep_in_fp32_modules or []
model.to(dtype).save_pretrained(tmp_path)
model_loaded = self.model_class.from_pretrained(tmp_path, torch_dtype=dtype).to(torch_device)
for name, param in model_loaded.named_parameters():
if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules):
if fp32_modules and any(
module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules
):
assert param.data.dtype == torch.float32
else:
assert param.data.dtype == dtype