From 8029cd7ef09a34208c1941c1cb299f59ab47b2c3 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 12 Jun 2025 11:08:19 +0530
Subject: [PATCH] add test and clarify.

---
 src/diffusers/hooks/group_offloading.py |  1 +
 tests/models/test_modeling_common.py    | 30 +++++++++++++++++++++++++
 2 files changed, 31 insertions(+)

diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py
index d00a007137..1ea60c3f33 100644
--- a/src/diffusers/hooks/group_offloading.py
+++ b/src/diffusers/hooks/group_offloading.py
@@ -231,6 +231,7 @@ class ModuleGroup:
             # The group is now considered offloaded to disk for the rest of the session.
             self._is_offloaded_to_disk = True
 
+            # We do this to free up the RAM that is still holding the tensor data.
             for tensor_obj in self.tensor_to_key.keys():
                 tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
             return
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index 5087bd0094..75e04b0a50 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -15,6 +15,7 @@
 
 import copy
 import gc
+import glob
 import inspect
 import json
 import os
@@ -1608,6 +1609,35 @@ class ModelTesterMixin:
         model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
         _ = model(**inputs_dict)[0]
 
+    @parameterized.expand([(False, "block_level"), (True, "leaf_level")])
+    @require_torch_accelerator
+    @torch.no_grad()
+    def test_group_offloading_with_disk(self, record_stream, offload_type):
+        torch.manual_seed(0)
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+
+        if not getattr(model, "_supports_group_offloading", True):
+            return
+
+        torch.manual_seed(0)
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+        model.eval()
+        additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1}
+        with tempfile.TemporaryDirectory() as tmpdir:
+            model.enable_group_offload(
+                torch_device,
+                offload_type=offload_type,
+                offload_to_disk_path=tmpdir,
+                use_stream=True,
+                record_stream=record_stream,
+                **additional_kwargs,
+            )
+            has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
+            assert has_safetensors
+            _ = model(**inputs_dict)[0]
+
     def test_auto_model(self, expected_max_diff=5e-5):
         if self.forward_requires_fresh_args:
             model = self.model_class(**self.init_dict)