mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[tests] add test for torch.compile + group offloading (#11670)
* add a test for group offloading + compilation. * tests
This commit is contained in:
@@ -1744,6 +1744,10 @@ class ModelPushToHubTester(unittest.TestCase):
|
||||
delete_repo(self.repo_id, token=TOKEN)
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_2
|
||||
@is_torch_compile
|
||||
@slow
|
||||
class TorchCompileTesterMixin:
|
||||
def setUp(self):
|
||||
# clean up the VRAM before each test
|
||||
@@ -1759,12 +1763,7 @@ class TorchCompileTesterMixin:
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_2
|
||||
@is_torch_compile
|
||||
@slow
|
||||
def test_torch_compile_recompilation_and_graph_break(self):
|
||||
torch.compiler.reset()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
@@ -1778,6 +1777,31 @@ class TorchCompileTesterMixin:
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
def test_compile_with_group_offloading(self):
|
||||
torch._dynamo.config.cache_size_limit = 10000
|
||||
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
if not getattr(model, "_supports_group_offloading", True):
|
||||
return
|
||||
|
||||
model.eval()
|
||||
# TODO: Can test for other group offloading kwargs later if needed.
|
||||
group_offload_kwargs = {
|
||||
"onload_device": "cuda",
|
||||
"offload_device": "cpu",
|
||||
"offload_type": "block_level",
|
||||
"num_blocks_per_group": 1,
|
||||
"use_stream": True,
|
||||
"non_blocking": True,
|
||||
}
|
||||
model.enable_group_offload(**group_offload_kwargs)
|
||||
model.compile()
|
||||
with torch.no_grad():
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_2
|
||||
|
||||
Reference in New Issue
Block a user