diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 2121d1161f..8935d8084c 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -81,6 +81,7 @@ class ModuleGroup: self._is_offloaded_to_disk = False if self.offload_to_disk_path: + self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}.safetensors") all_tensors = [] param_names = [] for module in self.modules: @@ -96,9 +97,6 @@ class ModuleGroup: self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)} self.key_to_tensor = {v: k for k, v in self.tensor_to_key.items()} - group_id_key = "_".join(sorted(param_names)) - self._disk_offload_group_id = hashlib.md5(group_id_key.encode()).hexdigest()[:8] - self.cpu_param_dict = {} else: self.cpu_param_dict = self._init_cpu_param_dict() @@ -106,13 +104,6 @@ class ModuleGroup: if self.stream is None and self.record_stream: raise ValueError("`record_stream` cannot be True when `stream` is None.") - @property - def _disk_offload_file_path(self): - if self.offload_to_disk_path: - return os.path.join(self.offload_to_disk_path, f"group_{self._disk_offload_group_id}.safetensors") - - return None - def _init_cpu_param_dict(self): cpu_param_dict = {} if self.stream is None: @@ -173,7 +164,7 @@ class ModuleGroup: def _onload_from_disk(self, current_stream): if self.stream is not None: - loaded_cpu_tensors = safetensors.torch.load_file(self._disk_offload_file_path, device="cpu") + loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu") for key, tensor_obj in self.key_to_tensor.items(): self.cpu_param_dict[tensor_obj] = loaded_cpu_tensors[key] @@ -188,7 +179,7 @@ class ModuleGroup: onload_device = ( self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device ) - loaded_tensors = safetensors.torch.load_file(self._disk_offload_file_path, device=onload_device) + loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device) for key, tensor_obj in self.key_to_tensor.items(): tensor_obj.data = loaded_tensors[key] @@ -231,10 +222,10 @@ class ModuleGroup: # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not # we perform a write. # Check if the file has been saved in this session or if it already exists on disk. - if not self._is_offloaded_to_disk and not os.path.exists(self._disk_offload_file_path): - os.makedirs(os.path.dirname(self._disk_offload_file_path), exist_ok=True) + if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path): + os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True) tensors_to_save = {key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()} - safetensors.torch.save_file(tensors_to_save, self._disk_offload_file_path) + safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path) # The group is now considered offloaded to disk for the rest of the session. self._is_offloaded_to_disk = True