mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
Author: sayakpaul
Date:   2025-06-19 18:41:46 +05:30
Parent: bcb71c9c8b
Commit: 33f30ef86e
3 changed files with 27 additions and 18 deletions


@@ -65,7 +65,7 @@ class ModuleGroup:
         low_cpu_mem_usage: bool = False,
         onload_self: bool = True,
         offload_to_disk_path: Optional[str] = None,
-        _enable_deepnvme_disk_offloading: Optional[bool] = False
+        _enable_deepnvme_disk_offloading: Optional[bool] = False,
     ) -> None:
         self.modules = modules
         self.offload_device = offload_device
@@ -86,7 +86,7 @@ class ModuleGroup:
         if self.offload_to_disk_path:
             self._enable_deepnvme_disk_offloading = _enable_deepnvme_disk_offloading
             ext = ".pt" if _enable_deepnvme_disk_offloading else ".safetensors"
-            self.param_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}.{ext}")
+            self.param_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}{ext}")
             all_tensors = []
             for module in self.modules:
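Since `ext` already carries a leading dot (`.pt` or `.safetensors`), the old f-string produced doubled separators such as `group_123..safetensors`; the hunk above simply drops the extra dot. A minimal sketch of the corrected naming, with a made-up helper and group id:

import os

def group_param_file_path(offload_dir: str, group_id: int, use_deepnvme: bool) -> str:
    # ext already includes the leading dot, so it is appended directly to the stem.
    ext = ".pt" if use_deepnvme else ".safetensors"
    return os.path.join(offload_dir, f"group_{group_id}{ext}")

print(group_param_file_path("offload", 140231, use_deepnvme=False))  # offload/group_140231.safetensors
print(group_param_file_path("offload", 140231, use_deepnvme=True))   # offload/group_140231.pt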
@@ -160,7 +160,10 @@ class ModuleGroup:
             with context:
                 if self.stream is not None:
                     # Load to CPU from disk, pin, and async copy to device for overlapping transfer and compute
-                    loaded_cpu_tensors = safetensors.torch.load_file(self.param_file_path, device="cpu")
+                    if self._enable_deepnvme_disk_offloading:
+                        loaded_cpu_tensors = torch.load(self.param_file_path, weights_only=True, map_location="cpu")
+                    else:
+                        loaded_cpu_tensors = safetensors.torch.load_file(self.param_file_path, device="cpu")
                     for key, tensor_obj in self.key_to_tensor.items():
                         pinned_tensor = loaded_cpu_tensors[key].pin_memory()
                         tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
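The streamed branch above first materializes the group on CPU (from a `.pt` file via `torch.load` when DeepNVMe offloading is on, otherwise from safetensors), pins the tensors, and copies them to the accelerator with `non_blocking=True` so the transfer can overlap with compute. A self-contained sketch of that pattern, with hypothetical argument names and no tie-in to the hook machinery:

import safetensors.torch
import torch

def load_group_to_device(param_file_path: str, use_deepnvme: bool, device: torch.device) -> dict:
    # DeepNVMe-offloaded groups are torch pickles; the default path uses safetensors.
    if use_deepnvme:
        cpu_tensors = torch.load(param_file_path, weights_only=True, map_location="cpu")
    else:
        cpu_tensors = safetensors.torch.load_file(param_file_path, device="cpu")

    onloaded = {}
    for key, tensor in cpu_tensors.items():
        # Pinned host memory makes the host-to-device copy eligible for async DMA.
        onloaded[key] = tensor.pin_memory().to(device, non_blocking=True)
    return onloaded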
@@ -171,7 +174,12 @@ class ModuleGroup:
                     onload_device = (
                         self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
                     )
-                    loaded_tensors = safetensors.torch.load_file(self.param_file_path, device=onload_device)
+                    if self._enable_deepnvme_disk_offloading:
+                        loaded_tensors = torch.load(
+                            self.param_file_path, weights_only=True, map_location=onload_device
+                        )
+                    else:
+                        loaded_tensors = safetensors.torch.load_file(self.param_file_path, device=onload_device)
                     for key, tensor_obj in self.key_to_tensor.items():
                         tensor_obj.data = loaded_tensors[key]
                     return
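In the non-streamed branch the tensors are loaded straight onto the target device. `safetensors.torch.load_file` wants a device string, which is presumably why the code narrows a `torch.device` to its `.type` first; `torch.load` accepts either form through `map_location`. A tiny sketch of that conversion (function name is illustrative):

import torch

def resolve_onload_device(onload_device) -> str:
    # safetensors expects e.g. "cpu" or "cuda"; torch.device objects are narrowed to their type.
    return onload_device.type if isinstance(onload_device, torch.device) else onload_device

assert resolve_onload_device(torch.device("cuda:0")) == "cuda"
assert resolve_onload_device("cpu") == "cpu"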
@@ -232,10 +240,10 @@ class ModuleGroup:
                 tensors_to_save = {
                     key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()
                 }
-                if not self._enable_deepnvme_disk_offloading:
-                    safetensors.torch.save_file(tensors_to_save, self.param_file_path)
-                else:
+                if self._enable_deepnvme_disk_offloading:
                     _fast_aio_save(tensors_to_save, self.param_file_path)
+                else:
+                    safetensors.torch.save_file(tensors_to_save, self.param_file_path)
                 # The group is now considered offloaded to disk for the rest of the session.
                 self._is_offloaded_to_disk = True
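On offload, the group's tensors are gathered on the offload device and written out either with safetensors or, when the private DeepNVMe flag is set, through the `_fast_aio_save` helper touched in the last hunk of this commit. A rough stand-in sketch; `save_group` is hypothetical and plain `torch.save` is used here in place of the DeepSpeed-backed writer:

import safetensors.torch
import torch

def save_group(tensors: dict, param_file_path: str, use_deepnvme: bool) -> None:
    # Tensors are assumed to already live on the offload device (typically CPU).
    if use_deepnvme:
        # Stand-in for _fast_aio_save: the real helper streams the pickle through an
        # NVMe-aware FastFileWriter rather than a regular torch.save.
        torch.save(tensors, param_file_path)
    else:
        safetensors.torch.save_file(tensors, param_file_path)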
@@ -435,7 +443,7 @@ def apply_group_offloading(
     record_stream: bool = False,
     low_cpu_mem_usage: bool = False,
     offload_to_disk_path: Optional[str] = None,
-    _enable_deepnvme_disk_offloading: Optional[bool] = False
+    _enable_deepnvme_disk_offloading: Optional[bool] = False,
 ) -> None:
     r"""
     Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
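For orientation, a hedged usage sketch of the public entry point with the new knobs; it assumes this branch is installed, a CUDA device is available, and it omits most of the signature. `_enable_deepnvme_disk_offloading` is private and may change:

import torch
from diffusers.hooks import apply_group_offloading

# Any torch.nn.Module works; a toy stack of Linear layers keeps the sketch self-contained.
model = torch.nn.Sequential(*[torch.nn.Linear(1024, 1024) for _ in range(8)])

apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    offload_to_disk_path="/tmp/group_offload",  # parameter groups are spilled to this directory
    _enable_deepnvme_disk_offloading=True,      # write .pt files via the DeepNVMe path instead of safetensors
)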
@@ -541,7 +549,7 @@ def apply_group_offloading(
             stream=stream,
             record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
-            _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading
+            _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
         )
     elif offload_type == "leaf_level":
         _apply_group_offloading_leaf_level(
@@ -553,7 +561,7 @@ def apply_group_offloading(
             stream=stream,
             record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
-            _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading
+            _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
         )
     else:
         raise ValueError(f"Unsupported offload_type: {offload_type}")
@@ -569,7 +577,7 @@ def _apply_group_offloading_block_level(
     record_stream: Optional[bool] = False,
     low_cpu_mem_usage: bool = False,
     offload_to_disk_path: Optional[str] = None,
-    _enable_deepnvme_disk_offloading: Optional[bool] = False
+    _enable_deepnvme_disk_offloading: Optional[bool] = False,
 ) -> None:
     r"""
     This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
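As the docstring above starts to explain, block-level offloading walks the model's `torch.nn.ModuleList`/`torch.nn.Sequential` containers and groups consecutive blocks so that only one group needs to sit on the accelerator at a time. A simplified, hypothetical sketch of just the grouping step (the real function also builds `ModuleGroup` objects and registers hooks):

import torch

def chunk_blocks(module: torch.nn.Module, num_blocks_per_group: int):
    # Yield (child_name, blocks) chunks for every ModuleList/Sequential child.
    for name, submodule in module.named_children():
        if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
            continue
        for i in range(0, len(submodule), num_blocks_per_group):
            yield name, list(submodule[i : i + num_blocks_per_group])

toy = torch.nn.ModuleDict({"blocks": torch.nn.ModuleList(torch.nn.Linear(8, 8) for _ in range(5))})
print([len(group) for _, group in chunk_blocks(toy, num_blocks_per_group=2)])  # [2, 2, 1]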
@@ -630,7 +638,7 @@ def _apply_group_offloading_block_level(
                 record_stream=record_stream,
                 low_cpu_mem_usage=low_cpu_mem_usage,
                 onload_self=True,
-                _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading
+                _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
             )
             matched_module_groups.append(group)
             for j in range(i, i + len(current_modules)):
@@ -682,7 +690,7 @@ def _apply_group_offloading_leaf_level(
     record_stream: Optional[bool] = False,
     low_cpu_mem_usage: bool = False,
     offload_to_disk_path: Optional[str] = None,
-    _enable_deepnvme_disk_offloading: Optional[bool] = False
+    _enable_deepnvme_disk_offloading: Optional[bool] = False,
 ) -> None:
     r"""
     This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
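Leaf-level offloading, by contrast, hooks each individual leaf module that owns parameters, so peak memory approaches the single largest layer at the cost of more frequent transfers. A rough sketch of how such leaves could be enumerated (names are illustrative, not the library's internals):

import torch

def iter_param_leaves(model: torch.nn.Module):
    # A "leaf" here is a module with no children that directly owns parameters.
    for name, submodule in model.named_modules():
        has_children = next(submodule.children(), None) is not None
        has_params = next(submodule.parameters(recurse=False), None) is not None
        if not has_children and has_params:
            yield name, submodule

model = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 4))
print([name for name, _ in iter_param_leaves(model)])  # ['0', '2']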
@@ -733,7 +741,7 @@ def _apply_group_offloading_leaf_level(
             record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
-            _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading
+            _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
         )
         _apply_group_offloading_hook(submodule, group, None)
         modules_with_group_offloading.add(name)
@@ -781,7 +789,7 @@ def _apply_group_offloading_leaf_level(
             record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
             onload_self=True,
-            _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading
+            _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
         )
         _apply_group_offloading_hook(parent_module, group, None)
@@ -803,7 +811,7 @@ def _apply_group_offloading_leaf_level(
         record_stream=False,
         low_cpu_mem_usage=low_cpu_mem_usage,
         onload_self=True,
-        _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading
+        _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
     )
     _apply_lazy_group_offloading_hook(module, unmatched_group, None)


@@ -549,7 +549,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         record_stream: bool = False,
         low_cpu_mem_usage=False,
         offload_to_disk_path: Optional[str] = None,
-        _enable_deepnvme_disk_offloading: Optional[bool] = False
+        _enable_deepnvme_disk_offloading: Optional[bool] = False,
     ) -> None:
         r"""
         Activates group offloading for the current model.
@@ -600,7 +600,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             record_stream=record_stream,
             low_cpu_mem_usage=low_cpu_mem_usage,
             offload_to_disk_path=offload_to_disk_path,
-            _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading
+            _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
         )

     def save_pretrained(
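On the `ModelMixin` side the flag is just forwarded to `apply_group_offloading`, so end-to-end usage looks roughly like the sketch below. The checkpoint and paths are examples only, and the private flag may be renamed or removed:

import torch
from diffusers import FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", subfolder="transformer", torch_dtype=torch.bfloat16
)
transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    offload_to_disk_path="/mnt/nvme/flux_offload",
    _enable_deepnvme_disk_offloading=True,  # group files land on disk as .pt via the DeepNVMe writer
)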


@@ -420,3 +420,4 @@ def _fast_aio_save(buffer, file, single_io_buffer=False):
     ds_fast_writer = FastFileWriter(file_path=file, config=fast_writer_config)
     _nvme_save(f=ds_fast_writer, obj=buffer, _use_new_zipfile_serialization=False)
     ds_fast_writer.close()
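Whatever `_fast_aio_save` writes is read back by plain `torch.load(..., weights_only=True)` in the onload path above, so a spilled group file can be inspected directly; a quick sketch with an illustrative path:

import torch

# Group files are named group_<id>.pt when DeepNVMe disk offloading is enabled.
state = torch.load("/mnt/nvme/flux_offload/group_140231.pt", weights_only=True, map_location="cpu")
for key, tensor in state.items():
    print(key, tuple(tensor.shape), tensor.dtype)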