mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Fix EMAModel test_from_pretrained (#10325)
This commit is contained in:
@@ -67,6 +67,7 @@ class EMAModelTests(unittest.TestCase):
|
||||
|
||||
# Load the EMA model from the saved directory
|
||||
loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=False)
|
||||
loaded_ema_unet.to(torch_device)
|
||||
|
||||
# Check that the shadow parameters of the loaded model match the original EMA model
|
||||
for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params):
|
||||
@@ -221,6 +222,7 @@ class EMAModelTestsForeach(unittest.TestCase):
|
||||
|
||||
# Load the EMA model from the saved directory
|
||||
loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=True)
|
||||
loaded_ema_unet.to(torch_device)
|
||||
|
||||
# Check that the shadow parameters of the loaded model match the original EMA model
|
||||
for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params):
|
||||
|
||||
Reference in New Issue
Block a user