1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Fix EMAModel test_from_pretrained (#10325)

This commit is contained in:
hlky
2024-12-21 14:10:44 +00:00
committed by GitHub
parent a756694bf0
commit bf9a641f1a

View File

@@ -67,6 +67,7 @@ class EMAModelTests(unittest.TestCase):
# Load the EMA model from the saved directory
loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=False)
loaded_ema_unet.to(torch_device)
# Check that the shadow parameters of the loaded model match the original EMA model
for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params):
@@ -221,6 +222,7 @@ class EMAModelTestsForeach(unittest.TestCase):
# Load the EMA model from the saved directory
loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=True)
loaded_ema_unet.to(torch_device)
# Check that the shadow parameters of the loaded model match the original EMA model
for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params):