diff --git a/tests/models/test_models_vae.py b/tests/models/test_models_vae.py index 3b698624ff..5283639269 100644 --- a/tests/models/test_models_vae.py +++ b/tests/models/test_models_vae.py @@ -804,6 +804,7 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase): gc.collect() torch.cuda.empty_cache() + @torch.no_grad() def test_encode_decode(self): vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder") # TODO - update vae.to(torch_device)