1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

fix slow tsets (#3066)

* fix slow tsets

* make style
This commit is contained in:
Patrick von Platen
2023-04-12 10:23:52 +02:00
committed by GitHub
parent a89a14fa7a
commit 0c72006e3a

View File

@@ -411,7 +411,9 @@ class Transformer2DModelTests(unittest.TestCase):
assert attention_scores.shape == (1, 64, 64, 64)
output_slice = attention_scores[0, -1, -3:, -3:]
expected_slice = torch.tensor([0.0143, -0.6909, -2.1547, -1.8893, 1.4097, 0.1359, -0.2521, -1.3359, 0.2598])
expected_slice = torch.tensor(
[0.0143, -0.6909, -2.1547, -1.8893, 1.4097, 0.1359, -0.2521, -1.3359, 0.2598], device=torch_device
)
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
def test_spatial_transformer_timestep(self):
@@ -442,9 +444,11 @@ class Transformer2DModelTests(unittest.TestCase):
output_slice_1 = attention_scores_1[0, -1, -3:, -3:]
output_slice_2 = attention_scores_2[0, -1, -3:, -3:]
expected_slice = torch.tensor([-0.3923, -1.0923, -1.7144, -1.5570, 1.4154, 0.1738, -0.1157, -1.2998, -0.1703])
expected_slice = torch.tensor(
[-0.3923, -1.0923, -1.7144, -1.5570, 1.4154, 0.1738, -0.1157, -1.2998, -0.1703], device=torch_device
)
expected_slice_2 = torch.tensor(
[-0.4311, -1.1376, -1.7732, -1.5997, 1.3450, 0.0964, -0.1569, -1.3590, -0.2348]
[-0.4311, -1.1376, -1.7732, -1.5997, 1.3450, 0.0964, -0.1569, -1.3590, -0.2348], device=torch_device
)
assert torch.allclose(output_slice_1.flatten(), expected_slice, atol=1e-3)