From bf9a641f1a51368af5f3ae99cc460107d4fa2103 Mon Sep 17 00:00:00 2001 From: hlky Date: Sat, 21 Dec 2024 14:10:44 +0000 Subject: [PATCH] Fix EMAModel test_from_pretrained (#10325) --- tests/others/test_ema.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/others/test_ema.py b/tests/others/test_ema.py index 3443e6366f..7cf8f30ecc 100644 --- a/tests/others/test_ema.py +++ b/tests/others/test_ema.py @@ -67,6 +67,7 @@ class EMAModelTests(unittest.TestCase): # Load the EMA model from the saved directory loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=False) + loaded_ema_unet.to(torch_device) # Check that the shadow parameters of the loaded model match the original EMA model for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params): @@ -221,6 +222,7 @@ class EMAModelTestsForeach(unittest.TestCase): # Load the EMA model from the saved directory loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=True) + loaded_ema_unet.to(torch_device) # Check that the shadow parameters of the loaded model match the original EMA model for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params):