diff --git a/tests/pipelines/stable_unclip/test_stable_unclip.py b/tests/pipelines/stable_unclip/test_stable_unclip.py index 7bc351ad76..4b69a03fc5 100644 --- a/tests/pipelines/stable_unclip/test_stable_unclip.py +++ b/tests/pipelines/stable_unclip/test_stable_unclip.py @@ -189,6 +189,10 @@ class StableUnCLIPPipelineIntegrationTests(unittest.TestCase): pipe = StableUnCLIPPipeline.from_pretrained("fusing/stable-unclip-2-1-l", torch_dtype=torch.float16) pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) + # stable unclip will oom when integration tests are run on a V100, + # so turn on memory savings + pipe.enable_attention_slicing() + pipe.enable_sequential_cpu_offload() generator = torch.Generator(device="cpu").manual_seed(0) output = pipe("anime turle", generator=generator, output_type="np") diff --git a/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py b/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py index adbf3b2727..539e0dcebe 100644 --- a/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py +++ b/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py @@ -185,6 +185,10 @@ class StableUnCLIPImg2ImgPipelineIntegrationTests(unittest.TestCase): ) pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) + # stable unclip will oom when integration tests are run on a V100, + # so turn on memory savings + pipe.enable_attention_slicing() + pipe.enable_sequential_cpu_offload() generator = torch.Generator(device="cpu").manual_seed(0) output = pipe("anime turle", image=input_image, generator=generator, output_type="np") @@ -209,6 +213,10 @@ class StableUnCLIPImg2ImgPipelineIntegrationTests(unittest.TestCase): ) pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) + # stable unclip will oom when integration tests are run on a V100, + # so turn on memory savings + pipe.enable_attention_slicing() + pipe.enable_sequential_cpu_offload() generator = torch.Generator(device="cpu").manual_seed(0) output = pipe("anime turle", image=input_image, generator=generator, output_type="np")