1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

don't hardcode device in tests

This commit is contained in:
patil-suraj
2022-06-21 12:02:21 +02:00
parent dc966cc447
commit e3bf932404

View File

@@ -262,8 +262,6 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
sizes = (32, 32)
low_res_size = (4, 4)
torch_device = "cpu"
noise = torch.randn((batch_size, num_channels // 2) + sizes).to(torch_device)
low_res = torch.randn((batch_size, 3) + low_res_size).to(torch_device)
time_step = torch.tensor([10] * noise.shape[0], device=torch_device)
@@ -355,8 +353,6 @@ class GLIDETextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
transformer_dim = 32
seq_len = 16
torch_device = "cpu"
noise = torch.randn((batch_size, num_channels) + sizes).to(torch_device)
emb = torch.randn((batch_size, seq_len, transformer_dim)).to(torch_device)
time_step = torch.tensor([10] * noise.shape[0], device=torch_device)