From 16c955c5fdff7dc427488eb691411bcb2bedd68d Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Fri, 6 Jun 2025 11:34:44 +0530
Subject: [PATCH] [tests] add test for torch.compile + group offloading (#11670)

* add a test for group offloading + compilation.

* tests
---
 tests/models/test_modeling_common.py | 34 ++++++++++++++++++++++++----
 1 file changed, 29 insertions(+), 5 deletions(-)

diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index 453690c1c9..5087bd0094 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -1744,6 +1744,10 @@ class ModelPushToHubTester(unittest.TestCase):
         delete_repo(self.repo_id, token=TOKEN)
 
 
+@require_torch_gpu
+@require_torch_2
+@is_torch_compile
+@slow
 class TorchCompileTesterMixin:
     def setUp(self):
         # clean up the VRAM before each test
@@ -1759,12 +1763,7 @@ class TorchCompileTesterMixin:
         gc.collect()
         backend_empty_cache(torch_device)
 
-    @require_torch_gpu
-    @require_torch_2
-    @is_torch_compile
-    @slow
     def test_torch_compile_recompilation_and_graph_break(self):
-        torch.compiler.reset()
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**init_dict).to(torch_device)
 
@@ -1778,6 +1777,31 @@ class TorchCompileTesterMixin:
         _ = model(**inputs_dict)
         _ = model(**inputs_dict)
 
+    def test_compile_with_group_offloading(self):
+        torch._dynamo.config.cache_size_limit = 10000
+
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+
+        if not getattr(model, "_supports_group_offloading", True):
+            return
+
+        model.eval()
+        # TODO: Can test for other group offloading kwargs later if needed.
+        group_offload_kwargs = {
+            "onload_device": "cuda",
+            "offload_device": "cpu",
+            "offload_type": "block_level",
+            "num_blocks_per_group": 1,
+            "use_stream": True,
+            "non_blocking": True,
+        }
+        model.enable_group_offload(**group_offload_kwargs)
+        model.compile()
+        with torch.no_grad():
+            _ = model(**inputs_dict)
+            _ = model(**inputs_dict)
+
 
 @slow
 @require_torch_2