1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

add test and clarify.

This commit is contained in:
sayakpaul
2025-06-12 11:08:19 +05:30
parent 4e4842fb0b
commit 8029cd7ef0
2 changed files with 31 additions and 0 deletions

View File

@@ -231,6 +231,7 @@ class ModuleGroup:
# The group is now considered offloaded to disk for the rest of the session.
self._is_offloaded_to_disk = True
# We do this to free up the RAM that is still holding the tensor data.
for tensor_obj in self.tensor_to_key.keys():
tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
return

View File

@@ -15,6 +15,7 @@
import copy
import gc
import glob
import inspect
import json
import os
@@ -1608,6 +1609,35 @@ class ModelTesterMixin:
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
_ = model(**inputs_dict)[0]
@parameterized.expand([(False, "block_level"), (True, "leaf_level")])
@require_torch_accelerator
@torch.no_grad()
def test_group_offloading_with_disk(self, record_stream, offload_type):
    """Check that group offloading with on-disk storage serializes parameter
    groups to ``*.safetensors`` files and that a forward pass still succeeds.

    Args:
        record_stream: Forwarded to ``enable_group_offload``; toggles stream
            recording (used together with ``use_stream=True``).
        offload_type: ``"block_level"`` (requires ``num_blocks_per_group``)
            or ``"leaf_level"``.
    """
    # Build the model once; the original constructed it twice with the same
    # seed and init args, which was redundant.
    torch.manual_seed(0)
    init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
    model = self.model_class(**init_dict)
    if not getattr(model, "_supports_group_offloading", True):
        # Model class has opted out of group offloading; nothing to test.
        return
    model.eval()
    # block_level offloading needs a group size; leaf_level takes none.
    additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1}
    with tempfile.TemporaryDirectory() as tmpdir:
        model.enable_group_offload(
            torch_device,
            offload_type=offload_type,
            offload_to_disk_path=tmpdir,
            use_stream=True,
            record_stream=record_stream,
            **additional_kwargs,
        )
        # Offloaded parameter groups must have been written to disk.
        has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
        assert has_safetensors, f"expected offloaded .safetensors files in {tmpdir}"
        _ = model(**inputs_dict)[0]
def test_auto_model(self, expected_max_diff=5e-5):
if self.forward_requires_fresh_args:
model = self.model_class(**self.init_dict)