From da801e97ba03358b51a2b248df6f9c32d9ddf2ec Mon Sep 17 00:00:00 2001 From: DN6 Date: Tue, 13 Jan 2026 15:15:13 +0530 Subject: [PATCH] update --- tests/models/testing_utils/common.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/models/testing_utils/common.py b/tests/models/testing_utils/common.py index b46530b4af..b85f8890f5 100644 --- a/tests/models/testing_utils/common.py +++ b/tests/models/testing_utils/common.py @@ -462,13 +462,15 @@ class ModelTesterMixin: def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype): model = self.model_class(**self.get_init_dict()) model.to(torch_device) - fp32_modules = model._keep_in_fp32_modules + fp32_modules = model._keep_in_fp32_modules or [] model.to(dtype).save_pretrained(tmp_path) model_loaded = self.model_class.from_pretrained(tmp_path, torch_dtype=dtype).to(torch_device) for name, param in model_loaded.named_parameters(): - if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules): + if fp32_modules and any( + module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules + ): assert param.data.dtype == torch.float32 else: assert param.data.dtype == dtype