diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index cb9f1cd789..5b4c22a5f4 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -65,7 +65,7 @@ class ModuleGroup: low_cpu_mem_usage: bool = False, onload_self: bool = True, offload_to_disk_path: Optional[str] = None, - _enable_deepnvme_disk_offloading: Optional[bool] = False + _enable_deepnvme_disk_offloading: Optional[bool] = False, ) -> None: self.modules = modules self.offload_device = offload_device @@ -86,7 +86,7 @@ class ModuleGroup: if self.offload_to_disk_path: self._enable_deepnvme_disk_offloading = _enable_deepnvme_disk_offloading ext = ".pt" if _enable_deepnvme_disk_offloading else ".safetensors" - self.param_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}.{ext}") + self.param_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}{ext}") all_tensors = [] for module in self.modules: @@ -160,7 +160,10 @@ class ModuleGroup: with context: if self.stream is not None: # Load to CPU from disk, pin, and async copy to device for overlapping transfer and compute - loaded_cpu_tensors = safetensors.torch.load_file(self.param_file_path, device="cpu") + if self._enable_deepnvme_disk_offloading: + loaded_cpu_tensors = torch.load(self.param_file_path, weights_only=True, map_location="cpu") + else: + loaded_cpu_tensors = safetensors.torch.load_file(self.param_file_path, device="cpu") for key, tensor_obj in self.key_to_tensor.items(): pinned_tensor = loaded_cpu_tensors[key].pin_memory() tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking) @@ -171,7 +174,12 @@ class ModuleGroup: onload_device = ( self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device ) - loaded_tensors = safetensors.torch.load_file(self.param_file_path, device=onload_device) + if self._enable_deepnvme_disk_offloading: + loaded_tensors = torch.load( + self.param_file_path, weights_only=True, map_location=onload_device + ) + else: + loaded_tensors = safetensors.torch.load_file(self.param_file_path, device=onload_device) for key, tensor_obj in self.key_to_tensor.items(): tensor_obj.data = loaded_tensors[key] return @@ -232,10 +240,10 @@ class ModuleGroup: tensors_to_save = { key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items() } - if not self._enable_deepnvme_disk_offloading: - safetensors.torch.save_file(tensors_to_save, self.param_file_path) - else: + if self._enable_deepnvme_disk_offloading: _fast_aio_save(tensors_to_save, self.param_file_path) + else: + safetensors.torch.save_file(tensors_to_save, self.param_file_path) # The group is now considered offloaded to disk for the rest of the session. self._is_offloaded_to_disk = True @@ -435,7 +443,7 @@ def apply_group_offloading( record_stream: bool = False, low_cpu_mem_usage: bool = False, offload_to_disk_path: Optional[str] = None, - _enable_deepnvme_disk_offloading: Optional[bool] = False + _enable_deepnvme_disk_offloading: Optional[bool] = False, ) -> None: r""" Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and @@ -541,7 +549,7 @@ def apply_group_offloading( stream=stream, record_stream=record_stream, low_cpu_mem_usage=low_cpu_mem_usage, - _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading + _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading, ) elif offload_type == "leaf_level": _apply_group_offloading_leaf_level( @@ -553,7 +561,7 @@ def apply_group_offloading( stream=stream, record_stream=record_stream, low_cpu_mem_usage=low_cpu_mem_usage, - _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading + _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading, ) else: raise ValueError(f"Unsupported offload_type: {offload_type}") @@ -569,7 +577,7 @@ def _apply_group_offloading_block_level( record_stream: Optional[bool] = False, low_cpu_mem_usage: bool = False, offload_to_disk_path: Optional[str] = None, - _enable_deepnvme_disk_offloading: Optional[bool] = False + _enable_deepnvme_disk_offloading: Optional[bool] = False, ) -> None: r""" This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to @@ -630,7 +638,7 @@ def _apply_group_offloading_block_level( record_stream=record_stream, low_cpu_mem_usage=low_cpu_mem_usage, onload_self=True, - _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading + _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading, ) matched_module_groups.append(group) for j in range(i, i + len(current_modules)): @@ -682,7 +690,7 @@ def _apply_group_offloading_leaf_level( record_stream: Optional[bool] = False, low_cpu_mem_usage: bool = False, offload_to_disk_path: Optional[str] = None, - _enable_deepnvme_disk_offloading: Optional[bool] = False + _enable_deepnvme_disk_offloading: Optional[bool] = False, ) -> None: r""" This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory @@ -733,7 +741,7 @@ def _apply_group_offloading_leaf_level( record_stream=record_stream, low_cpu_mem_usage=low_cpu_mem_usage, onload_self=True, - _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading + _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading, ) _apply_group_offloading_hook(submodule, group, None) modules_with_group_offloading.add(name) @@ -781,7 +789,7 @@ def _apply_group_offloading_leaf_level( record_stream=record_stream, low_cpu_mem_usage=low_cpu_mem_usage, onload_self=True, - _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading + _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading, ) _apply_group_offloading_hook(parent_module, group, None) @@ -803,7 +811,7 @@ def _apply_group_offloading_leaf_level( record_stream=False, low_cpu_mem_usage=low_cpu_mem_usage, onload_self=True, - _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading + _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading, ) _apply_lazy_group_offloading_hook(module, unmatched_group, None) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index bf2bb3596d..1d0c681594 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -549,7 +549,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): record_stream: bool = False, low_cpu_mem_usage=False, offload_to_disk_path: Optional[str] = None, - _enable_deepnvme_disk_offloading: Optional[bool] = False + _enable_deepnvme_disk_offloading: Optional[bool] = False, ) -> None: r""" Activates group offloading for the current model. @@ -600,7 +600,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): record_stream=record_stream, low_cpu_mem_usage=low_cpu_mem_usage, offload_to_disk_path=offload_to_disk_path, - _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading + _enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading, ) def save_pretrained( diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py index 4f494ef8fe..aea77b5405 100644 --- a/src/diffusers/utils/state_dict_utils.py +++ b/src/diffusers/utils/state_dict_utils.py @@ -420,3 +420,4 @@ def _fast_aio_save(buffer, file, single_io_buffer=False): ds_fast_writer = FastFileWriter(file_path=file, config=fast_writer_config) _nvme_save(f=ds_fast_writer, obj=buffer, _use_new_zipfile_serialization=False) + ds_fast_writer.close()