diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py index f0f91a3a86..8aa2099154 100644 --- a/tests/models/test_models_unet_2d_condition.py +++ b/tests/models/test_models_unet_2d_condition.py @@ -785,8 +785,8 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_custom_diffusion_weights.bin"))) torch.manual_seed(0) new_model = self.model_class(**init_dict) - new_model.to(torch_device) new_model.load_attn_procs(tmpdirname, weight_name="pytorch_custom_diffusion_weights.bin") + new_model.to(torch_device) with torch.no_grad(): new_sample = new_model(**inputs_dict).sample