diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py index 3c3f13db9b..7d1503b91f 100644 --- a/tests/quantization/torchao/test_torchao.py +++ b/tests/quantization/torchao/test_torchao.py @@ -476,6 +476,18 @@ class TorchAoTest(unittest.TestCase): with self.assertRaises(ValueError): self.get_dummy_components(TorchAoConfig("int42")) + def test_sequential_cpu_offload(self): + r""" + A test that checks if inference runs as expected when sequential cpu offloading is enabled. + """ + quantization_config = TorchAoConfig("int8wo") + components = self.get_dummy_components(quantization_config) + pipe = FluxPipeline(**components) + pipe.enable_sequential_cpu_offload() + + inputs = self.get_dummy_inputs(torch_device) + _ = pipe(**inputs) + # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners @require_torch