1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Fix unique memory address when doing group-offloading with disk (#11767)

* fix memory address problem

* add more tests

* updates

* updates

* update

* _group_id = group_id

* update

* Apply suggestions from code review

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* update

* update

* update

* fix

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
This commit is contained in:
Sayak Paul
2025-07-09 21:29:34 +05:30
committed by GitHub
parent db715e2c8c
commit 2d3d376bc0
3 changed files with 167 additions and 9 deletions

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import os
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
@@ -37,7 +38,7 @@ logger = get_logger(__name__) # pylint: disable=invalid-name
_GROUP_OFFLOADING = "group_offloading"
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
_LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
_GROUP_ID_LAZY_LEAF = "lazy_leafs"
_SUPPORTED_PYTORCH_LAYERS = (
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
@@ -82,6 +83,7 @@ class ModuleGroup:
low_cpu_mem_usage: bool = False,
onload_self: bool = True,
offload_to_disk_path: Optional[str] = None,
group_id: Optional[int] = None,
) -> None:
self.modules = modules
self.offload_device = offload_device
@@ -100,7 +102,10 @@ class ModuleGroup:
self._is_offloaded_to_disk = False
if self.offload_to_disk_path:
self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}.safetensors")
# Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well.
self.group_id = group_id if group_id is not None else str(id(self))
short_hash = _compute_group_hash(self.group_id)
self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{short_hash}.safetensors")
all_tensors = []
for module in self.modules:
@@ -609,6 +614,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
for i in range(0, len(submodule), config.num_blocks_per_group):
current_modules = submodule[i : i + config.num_blocks_per_group]
group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
group = ModuleGroup(
modules=current_modules,
offload_device=config.offload_device,
@@ -621,6 +627,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
record_stream=config.record_stream,
low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
group_id=group_id,
)
matched_module_groups.append(group)
for j in range(i, i + len(current_modules)):
@@ -655,6 +662,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
stream=None,
record_stream=False,
onload_self=True,
group_id=f"{module.__class__.__name__}_unmatched_group",
)
if config.stream is None:
_apply_group_offloading_hook(module, unmatched_group, None, config=config)
@@ -686,6 +694,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
record_stream=config.record_stream,
low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
group_id=name,
)
_apply_group_offloading_hook(submodule, group, None, config=config)
modules_with_group_offloading.add(name)
@@ -732,6 +741,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
record_stream=config.record_stream,
low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
group_id=name,
)
_apply_group_offloading_hook(parent_module, group, None, config=config)
@@ -753,6 +763,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
record_stream=False,
low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
group_id=_GROUP_ID_LAZY_LEAF,
)
_apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
@@ -873,6 +884,12 @@ def _get_group_onload_device(module: torch.nn.Module) -> torch.device:
raise ValueError("Group offloading is not enabled for the provided module.")
def _compute_group_hash(group_id):
hashed_id = hashlib.sha256(group_id.encode("utf-8")).hexdigest()
# first 16 characters for a reasonably short but unique name
return hashed_id[:16]
def _maybe_remove_and_reapply_group_offloading(module: torch.nn.Module) -> None:
r"""
Removes the group offloading hook from the module and re-applies it. This is useful when the module has been

View File

@@ -1,4 +1,5 @@
import functools
import glob
import importlib
import importlib.metadata
import inspect
@@ -18,7 +19,7 @@ from collections import UserDict
from contextlib import contextmanager
from io import BytesIO, StringIO
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Union
import numpy as np
import PIL.Image
@@ -1392,6 +1393,103 @@ if TYPE_CHECKING:
else:
DevicePropertiesUserDict = UserDict
if is_torch_available():
from diffusers.hooks.group_offloading import (
_GROUP_ID_LAZY_LEAF,
_SUPPORTED_PYTORCH_LAYERS,
_compute_group_hash,
_find_parent_module_in_module_dict,
_gather_buffers_with_no_group_offloading_parent,
_gather_parameters_with_no_group_offloading_parent,
)
def _get_expected_safetensors_files(
module: torch.nn.Module,
offload_to_disk_path: str,
offload_type: str,
num_blocks_per_group: Optional[int] = None,
) -> Set[str]:
expected_files = set()
def get_hashed_filename(group_id: str) -> str:
short_hash = _compute_group_hash(group_id)
return os.path.join(offload_to_disk_path, f"group_{short_hash}.safetensors")
if offload_type == "block_level":
if num_blocks_per_group is None:
raise ValueError("num_blocks_per_group must be provided for 'block_level' offloading.")
# Handle groups of ModuleList and Sequential blocks
unmatched_modules = []
for name, submodule in module.named_children():
if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
unmatched_modules.append(module)
continue
for i in range(0, len(submodule), num_blocks_per_group):
current_modules = submodule[i : i + num_blocks_per_group]
if not current_modules:
continue
group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
expected_files.add(get_hashed_filename(group_id))
# Handle the group for unmatched top-level modules and parameters
for module in unmatched_modules:
expected_files.add(get_hashed_filename(f"{module.__class__.__name__}_unmatched_group"))
elif offload_type == "leaf_level":
# Handle leaf-level module groups
for name, submodule in module.named_modules():
if isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS):
# These groups will always have parameters, so a file is expected
expected_files.add(get_hashed_filename(name))
# Handle groups for non-leaf parameters/buffers
modules_with_group_offloading = {
name for name, sm in module.named_modules() if isinstance(sm, _SUPPORTED_PYTORCH_LAYERS)
}
parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
all_orphans = parameters + buffers
if all_orphans:
parent_to_tensors = {}
module_dict = dict(module.named_modules())
for tensor_name, _ in all_orphans:
parent_name = _find_parent_module_in_module_dict(tensor_name, module_dict)
if parent_name not in parent_to_tensors:
parent_to_tensors[parent_name] = []
parent_to_tensors[parent_name].append(tensor_name)
for parent_name in parent_to_tensors:
# A file is expected for each parent that gathers orphaned tensors
expected_files.add(get_hashed_filename(parent_name))
expected_files.add(get_hashed_filename(_GROUP_ID_LAZY_LEAF))
else:
raise ValueError(f"Unsupported offload_type: {offload_type}")
return expected_files
def _check_safetensors_serialization(
module: torch.nn.Module,
offload_to_disk_path: str,
offload_type: str,
num_blocks_per_group: Optional[int] = None,
) -> bool:
if not os.path.isdir(offload_to_disk_path):
return False, None, None
expected_files = _get_expected_safetensors_files(
module, offload_to_disk_path, offload_type, num_blocks_per_group
)
actual_files = set(glob.glob(os.path.join(offload_to_disk_path, "*.safetensors")))
missing_files = expected_files - actual_files
extra_files = actual_files - expected_files
is_correct = not missing_files and not extra_files
return is_correct, extra_files, missing_files
class Expectations(DevicePropertiesUserDict):
def get_expectation(self) -> Any:

View File

@@ -61,6 +61,7 @@ from diffusers.utils import (
from diffusers.utils.hub_utils import _add_variant
from diffusers.utils.testing_utils import (
CaptureLogger,
_check_safetensors_serialization,
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_peak_memory_stats,
@@ -1702,18 +1703,43 @@ class ModelTesterMixin:
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
_ = model(**inputs_dict)[0]
@parameterized.expand([(False, "block_level"), (True, "leaf_level")])
@parameterized.expand([("block_level", False), ("leaf_level", True)])
@require_torch_accelerator
@torch.no_grad()
def test_group_offloading_with_disk(self, record_stream, offload_type):
@torch.inference_mode()
def test_group_offloading_with_disk(self, offload_type, record_stream, atol=1e-5):
if not self.model_class._supports_group_offloading:
pytest.skip("Model does not support group offloading.")
torch.manual_seed(0)
def _has_generator_arg(model):
sig = inspect.signature(model.forward)
params = sig.parameters
return "generator" in params
def _run_forward(model, inputs_dict):
accepts_generator = _has_generator_arg(model)
if accepts_generator:
inputs_dict["generator"] = torch.manual_seed(0)
torch.manual_seed(0)
return model(**inputs_dict)[0]
if self.__class__.__name__ == "AutoencoderKLCosmosTests" and offload_type == "leaf_level":
pytest.skip("With `leaf_type` as the offloading type, it fails. Needs investigation.")
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.eval()
model.to(torch_device)
output_without_group_offloading = _run_forward(model, inputs_dict)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.eval()
additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1}
num_blocks_per_group = None if offload_type == "leaf_level" else 1
additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": num_blocks_per_group}
with tempfile.TemporaryDirectory() as tmpdir:
model.enable_group_offload(
torch_device,
@@ -1724,8 +1750,25 @@ class ModelTesterMixin:
**additional_kwargs,
)
has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
self.assertTrue(len(has_safetensors) > 0, "No safetensors found in the offload directory.")
_ = model(**inputs_dict)[0]
self.assertTrue(has_safetensors, "No safetensors found in the directory.")
# For "leaf-level", there is a prefetching hook which makes this check a bit non-deterministic
# in nature. So, skip it.
if offload_type != "leaf_level":
is_correct, extra_files, missing_files = _check_safetensors_serialization(
module=model,
offload_to_disk_path=tmpdir,
offload_type=offload_type,
num_blocks_per_group=num_blocks_per_group,
)
if not is_correct:
if extra_files:
raise ValueError(f"Found extra files: {', '.join(extra_files)}")
elif missing_files:
raise ValueError(f"Following files are missing: {', '.join(missing_files)}")
output_with_group_offloading = _run_forward(model, inputs_dict)
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading, atol=atol))
def test_auto_model(self, expected_max_diff=5e-5):
if self.forward_requires_fresh_args: