mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
don't hardcode device in tests
This commit is contained in:
@@ -262,8 +262,6 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
|
||||
sizes = (32, 32)
|
||||
low_res_size = (4, 4)
|
||||
|
||||
torch_device = "cpu"
|
||||
|
||||
noise = torch.randn((batch_size, num_channels // 2) + sizes).to(torch_device)
|
||||
low_res = torch.randn((batch_size, 3) + low_res_size).to(torch_device)
|
||||
time_step = torch.tensor([10] * noise.shape[0], device=torch_device)
|
||||
@@ -355,8 +353,6 @@ class GLIDETextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
transformer_dim = 32
|
||||
seq_len = 16
|
||||
|
||||
torch_device = "cpu"
|
||||
|
||||
noise = torch.randn((batch_size, num_channels) + sizes).to(torch_device)
|
||||
emb = torch.randn((batch_size, seq_len, transformer_dim)).to(torch_device)
|
||||
time_step = torch.tensor([10] * noise.shape[0], device=torch_device)
|
||||
|
||||
Reference in New Issue
Block a user