From e3bf932404fb47468eabac162a3a285a79f8cb55 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Tue, 21 Jun 2022 12:02:21 +0200 Subject: [PATCH] don't hardcode device in tests --- tests/test_modeling_utils.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 4b2dee698f..a58759b297 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -262,8 +262,6 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase): sizes = (32, 32) low_res_size = (4, 4) - torch_device = "cpu" - noise = torch.randn((batch_size, num_channels // 2) + sizes).to(torch_device) low_res = torch.randn((batch_size, 3) + low_res_size).to(torch_device) time_step = torch.tensor([10] * noise.shape[0], device=torch_device) @@ -355,8 +353,6 @@ class GLIDETextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase): transformer_dim = 32 seq_len = 16 - torch_device = "cpu" - noise = torch.randn((batch_size, num_channels) + sizes).to(torch_device) emb = torch.randn((batch_size, seq_len, transformer_dim)).to(torch_device) time_step = torch.tensor([10] * noise.shape[0], device=torch_device)