From f64d52dbca93051a7652db7aa241964235a71035 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 12 Sep 2023 21:20:47 +0530 Subject: [PATCH] fix custom diffusion tests (#4996) --- tests/models/test_models_unet_2d_condition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py index f0f91a3a86..8aa2099154 100644 --- a/tests/models/test_models_unet_2d_condition.py +++ b/tests/models/test_models_unet_2d_condition.py @@ -785,8 +785,8 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_custom_diffusion_weights.bin"))) torch.manual_seed(0) new_model = self.model_class(**init_dict) - new_model.to(torch_device) new_model.load_attn_procs(tmpdirname, weight_name="pytorch_custom_diffusion_weights.bin") + new_model.to(torch_device) with torch.no_grad(): new_sample = new_model(**inputs_dict).sample