mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
fix OOM for test_vae_tiling (#7510)
use float16 and add torch.no_grad()
This commit is contained in:
@@ -1118,8 +1118,10 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
|
||||
assert torch_all_close(actual_output, expected_output, atol=5e-3)
|
||||
|
||||
def test_vae_tiling(self):
|
||||
vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder")
|
||||
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", vae=vae, safety_checker=None)
|
||||
vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", vae=vae, safety_checker=None, torch_dtype=torch.float16
|
||||
)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
@@ -1143,6 +1145,7 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
|
||||
|
||||
# test that tiled decode works with various shapes
|
||||
shapes = [(1, 4, 73, 97), (1, 4, 97, 73), (1, 4, 49, 65), (1, 4, 65, 49)]
|
||||
for shape in shapes:
|
||||
image = torch.zeros(shape, device=torch_device)
|
||||
pipe.vae.decode(image)
|
||||
with torch.no_grad():
|
||||
for shape in shapes:
|
||||
image = torch.zeros(shape, device=torch_device)
|
||||
pipe.vae.decode(image)
|
||||
|
||||
@@ -124,9 +124,10 @@ class SDFunctionTesterMixin:
|
||||
|
||||
# test that tiled decode works with various shapes
|
||||
shapes = [(1, 4, 73, 97), (1, 4, 97, 73), (1, 4, 49, 65), (1, 4, 65, 49)]
|
||||
for shape in shapes:
|
||||
zeros = torch.zeros(shape).to(torch_device)
|
||||
pipe.vae.decode(zeros)
|
||||
with torch.no_grad():
|
||||
for shape in shapes:
|
||||
zeros = torch.zeros(shape).to(torch_device)
|
||||
pipe.vae.decode(zeros)
|
||||
|
||||
def test_freeu_enabled(self):
|
||||
components = self.get_dummy_components()
|
||||
|
||||
Reference in New Issue
Block a user