mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
fix test_output_pretrained for GLIDESuperResUNetModel
This commit is contained in:
@@ -320,17 +320,14 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
assert image is not None, "Make sure output is not None"
|
||||
|
||||
# TODO (patil-suraj): Check why GLIDESuperResUNetModel always outputs zero
|
||||
@unittest.skip("GLIDESuperResUNetModel always outputs zero")
|
||||
def test_output_pretrained(self):
|
||||
model = GLIDESuperResUNetModel.from_pretrained("fusing/glide-super-res-dummy")
|
||||
model.eval()
|
||||
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(0)
|
||||
|
||||
noise = torch.randn(1, 3, 32, 32)
|
||||
noise = torch.randn(1, 3, 64, 64)
|
||||
low_res = torch.randn(1, 3, 4, 4)
|
||||
time_step = torch.tensor([42] * noise.shape[0])
|
||||
|
||||
@@ -340,9 +337,8 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
|
||||
output, _ = torch.split(output, 3, dim=1)
|
||||
output_slice = output[0, -1, -3:, -3:].flatten()
|
||||
# fmt: off
|
||||
expected_output_slice = torch.tensor([ 0.2891, -0.1899, 0.2595, -0.6214, 0.0968, -0.2622, 0.4688, 0.1311, 0.0053])
|
||||
expected_output_slice = torch.tensor([-22.8782, -23.2652, -15.3966, -22.8034, -23.3159, -15.5640, -15.3970, -15.4614, - 10.4370])
|
||||
# fmt: on
|
||||
print(output_slice)
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user