From 1e5eaca754bce676ce9142cab7ccaaee78df4696 Mon Sep 17 00:00:00 2001 From: Will Berman Date: Sat, 18 Feb 2023 19:24:52 -0800 Subject: [PATCH] stable unclip integration tests turn on memory saving (#2408) * stable unclip integration tests turn on memory saving * add note on turning on memory saving --- tests/pipelines/stable_unclip/test_stable_unclip.py | 4 ++++ .../pipelines/stable_unclip/test_stable_unclip_img2img.py | 8 ++++++++ 2 files changed, 12 insertions(+) diff --git a/tests/pipelines/stable_unclip/test_stable_unclip.py b/tests/pipelines/stable_unclip/test_stable_unclip.py index 7bc351ad76..4b69a03fc5 100644 --- a/tests/pipelines/stable_unclip/test_stable_unclip.py +++ b/tests/pipelines/stable_unclip/test_stable_unclip.py @@ -189,6 +189,10 @@ class StableUnCLIPPipelineIntegrationTests(unittest.TestCase): pipe = StableUnCLIPPipeline.from_pretrained("fusing/stable-unclip-2-1-l", torch_dtype=torch.float16) pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) + # stable unclip will oom when integration tests are run on a V100, + # so turn on memory savings + pipe.enable_attention_slicing() + pipe.enable_sequential_cpu_offload() generator = torch.Generator(device="cpu").manual_seed(0) output = pipe("anime turle", generator=generator, output_type="np") diff --git a/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py b/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py index adbf3b2727..539e0dcebe 100644 --- a/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py +++ b/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py @@ -185,6 +185,10 @@ class StableUnCLIPImg2ImgPipelineIntegrationTests(unittest.TestCase): ) pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) + # stable unclip will oom when integration tests are run on a V100, + # so turn on memory savings + pipe.enable_attention_slicing() + pipe.enable_sequential_cpu_offload() generator = torch.Generator(device="cpu").manual_seed(0) output = pipe("anime turle", image=input_image, generator=generator, output_type="np") @@ -209,6 +213,10 @@ class StableUnCLIPImg2ImgPipelineIntegrationTests(unittest.TestCase): ) pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) + # stable unclip will oom when integration tests are run on a V100, + # so turn on memory savings + pipe.enable_attention_slicing() + pipe.enable_sequential_cpu_offload() generator = torch.Generator(device="cpu").manual_seed(0) output = pipe("anime turle", image=input_image, generator=generator, output_type="np")