diff --git a/tests/test_ema.py b/tests/test_ema.py index 9f99457080..c532681ef0 100644 --- a/tests/test_ema.py +++ b/tests/test_ema.py @@ -33,11 +33,9 @@ class EMAModelTests(unittest.TestCase): generator = torch.manual_seed(0) def get_models(self, decay=0.9999): - unet = UNet2DConditionModel.from_pretrained(self.model_id, subfolder="unet", device=torch_device) - ema_unet = UNet2DConditionModel.from_pretrained(self.model_id, subfolder="unet") - ema_unet = EMAModel( - ema_unet.parameters(), decay=decay, model_cls=UNet2DConditionModel, model_config=ema_unet.config - ) + unet = UNet2DConditionModel.from_pretrained(self.model_id, subfolder="unet") + unet = unet.to(torch_device) + ema_unet = EMAModel(unet.parameters(), decay=decay, model_cls=UNet2DConditionModel, model_config=unet.config) return unet, ema_unet def get_dummy_inputs(self): @@ -149,6 +147,7 @@ class EMAModelTests(unittest.TestCase): with tempfile.TemporaryDirectory() as tmpdir: ema_unet.save_pretrained(tmpdir) loaded_unet = UNet2DConditionModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel) + loaded_unet = loaded_unet.to(unet.device) # Since no EMA step has been performed the outputs should match. output = unet(noisy_latents, timesteps, encoder_hidden_states).sample