diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py index 70a1c0d86f..d5ad5ee0d7 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py @@ -14,6 +14,7 @@ # limitations under the License. import copy +import gc import tempfile import unittest @@ -1024,6 +1025,11 @@ class StableDiffusionXLPipelineFastTests( @slow class StableDiffusionXLPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + def test_stable_diffusion_lcm(self): torch.manual_seed(0) unet = UNet2DConditionModel.from_pretrained(