From aa79d7da46ce0c2ae57a57a87c9d0b786cef370b Mon Sep 17 00:00:00 2001
From: Aryan
Date: Tue, 14 Jan 2025 09:54:06 +0530
Subject: [PATCH] Test sequential cpu offload for torchao quantization (#10506)

test sequential cpu offload
---
 tests/quantization/torchao/test_torchao.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py
index 3c3f13db9b..7d1503b91f 100644
--- a/tests/quantization/torchao/test_torchao.py
+++ b/tests/quantization/torchao/test_torchao.py
@@ -476,6 +476,18 @@ class TorchAoTest(unittest.TestCase):
         with self.assertRaises(ValueError):
             self.get_dummy_components(TorchAoConfig("int42"))
 
+    def test_sequential_cpu_offload(self):
+        r"""
+        A test that checks if inference runs as expected when sequential cpu offloading is enabled.
+        """
+        quantization_config = TorchAoConfig("int8wo")
+        components = self.get_dummy_components(quantization_config)
+        pipe = FluxPipeline(**components)
+        pipe.enable_sequential_cpu_offload()
+
+        inputs = self.get_dummy_inputs(torch_device)
+        _ = pipe(**inputs)
+
 
 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch