1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

fix OOM for test_vae_tiling (#7510)

use float16 and add torch.no_grad()
This commit is contained in:
YiYi Xu
2024-03-28 16:52:39 -10:00
committed by GitHub
parent e49c04d5d6
commit 34c90dbb31
2 changed files with 12 additions and 8 deletions

View File

@@ -1118,8 +1118,10 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
assert torch_all_close(actual_output, expected_output, atol=5e-3)
def test_vae_tiling(self):
vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder")
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", vae=vae, safety_checker=None)
vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", vae=vae, safety_checker=None, torch_dtype=torch.float16
)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -1143,6 +1145,7 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
# test that tiled decode works with various shapes
shapes = [(1, 4, 73, 97), (1, 4, 97, 73), (1, 4, 49, 65), (1, 4, 65, 49)]
for shape in shapes:
image = torch.zeros(shape, device=torch_device)
pipe.vae.decode(image)
with torch.no_grad():
for shape in shapes:
image = torch.zeros(shape, device=torch_device)
pipe.vae.decode(image)

View File

@@ -124,9 +124,10 @@ class SDFunctionTesterMixin:
# test that tiled decode works with various shapes
shapes = [(1, 4, 73, 97), (1, 4, 97, 73), (1, 4, 49, 65), (1, 4, 65, 49)]
for shape in shapes:
zeros = torch.zeros(shape).to(torch_device)
pipe.vae.decode(zeros)
with torch.no_grad():
for shape in shapes:
zeros = torch.zeros(shape).to(torch_device)
pipe.vae.decode(zeros)
def test_freeu_enabled(self):
components = self.get_dummy_components()