From 7a001c3ee2d51ecc69987da050468a318414afa3 Mon Sep 17 00:00:00 2001
From: kaixuanliu
Date: Thu, 13 Nov 2025 14:27:12 +0800
Subject: [PATCH] adjust unit tests for `test_save_load_float16` (#12500)

* adjust unit tests for wan pipeline

Signed-off-by: Liu, Kaixuan

* update code

Signed-off-by: Liu, Kaixuan

* avoid adjusting common `get_dummy_components` API

Signed-off-by: Liu, Kaixuan

* use `from_pretrained` for `transformer` and `transformer_2`

Signed-off-by: Liu, Kaixuan

* update code

Signed-off-by: Liu, Kaixuan

* update

Signed-off-by: Liu, Kaixuan

---------

Signed-off-by: Liu, Kaixuan
Co-authored-by: Sayak Paul
Co-authored-by: Dhruv Nair
---
 tests/pipelines/test_pipelines_common.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index 2af4ad0314..e2bbce7b0e 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -1422,7 +1422,18 @@ class PipelineTesterMixin:
     def test_save_load_float16(self, expected_max_diff=1e-2):
         components = self.get_dummy_components()
         for name, module in components.items():
-            if hasattr(module, "half"):
+            # Account for components with _keep_in_fp32_modules
+            if hasattr(module, "_keep_in_fp32_modules") and module._keep_in_fp32_modules is not None:
+                for name, param in module.named_parameters():
+                    if any(
+                        module_to_keep_in_fp32 in name.split(".")
+                        for module_to_keep_in_fp32 in module._keep_in_fp32_modules
+                    ):
+                        param.data = param.data.to(torch_device).to(torch.float32)
+                    else:
+                        param.data = param.data.to(torch_device).to(torch.float16)
+
+            elif hasattr(module, "half"):
                 components[name] = module.to(torch_device).half()
 
         pipe = self.pipeline_class(**components)
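
Note (editor's illustration, not part of the patch): the new branch mirrors how models that declare `_keep_in_fp32_modules` keep selected submodules in float32 while everything else is cast to float16. The sketch below is a minimal, self-contained demonstration of that casting rule; `ToyTransformer` and `cast_half_keeping_fp32` are hypothetical names invented for this example, and the `_keep_in_fp32_modules` class attribute simply stands in for the convention the test checks.

```python
import torch
from torch import nn


class ToyTransformer(nn.Module):
    # Hypothetical stand-in for a diffusers transformer; the class attribute
    # mirrors the `_keep_in_fp32_modules` convention the test looks for.
    _keep_in_fp32_modules = ["norm"]

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        self.norm = nn.LayerNorm(4)


def cast_half_keeping_fp32(module: nn.Module, device: str = "cpu") -> None:
    # Cast every parameter to float16 except those whose dotted path contains
    # a module name listed in `_keep_in_fp32_modules` (the same rule the
    # patched test applies).
    keep_in_fp32 = getattr(module, "_keep_in_fp32_modules", None) or []
    for param_name, param in module.named_parameters():
        if any(target in param_name.split(".") for target in keep_in_fp32):
            param.data = param.data.to(device).to(torch.float32)
        else:
            param.data = param.data.to(device).to(torch.float16)


model = ToyTransformer()
cast_half_keeping_fp32(model)
assert model.norm.weight.dtype == torch.float32  # kept in fp32
assert model.proj.weight.dtype == torch.float16  # downcast to fp16
```

Matching against `name.split(".")` rather than a plain substring is deliberate: a keep-list entry like `norm` matches only a module path segment named exactly `norm`, not an unrelated parameter such as `unnorm.weight` that merely contains the string.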