diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py
index c2adb6ab1d..d00a007137 100644
--- a/src/diffusers/hooks/group_offloading.py
+++ b/src/diffusers/hooks/group_offloading.py
@@ -216,13 +216,20 @@ class ModuleGroup:
     def offload_(self):
         r"""Offloads the group of modules to the offload_device."""
         if self.offload_to_disk_path:
-            if not self._is_offloaded_to_disk:
+            # TODO: we can potentially optimize this code path by checking if _all_ the desired
+            # safetensors files exist on disk and, if so, skipping this step entirely to reduce
+            # IO overhead. Currently, we just check whether the given `safetensors_file_path`
+            # exists and, if not, perform a write.
+            # Check if the file has been saved in this session or if it already exists on disk.
+            if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
                 os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
                 tensors_to_save = {
                     key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()
                 }
                 safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
-            self._is_offloaded_to_disk = True
+
+            # The group is now considered offloaded to disk for the rest of the session.
+            self._is_offloaded_to_disk = True
             for tensor_obj in self.tensor_to_key.keys():
                 tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
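
For context, here is a minimal, self-contained sketch of the skip-if-exists pattern the hunk introduces. The attribute and method names (`tensor_to_key`, `safetensors_file_path`, `_is_offloaded_to_disk`, `offload_`) mirror the diff, but the wrapper class `DiskOffloadSketch` is hypothetical and not part of the diffusers API; this illustrates the logic under those assumptions rather than reproducing the library's implementation.

import os

import safetensors.torch
import torch


class DiskOffloadSketch:
    """Hypothetical stand-in for the disk-offload parts of `ModuleGroup`."""

    def __init__(self, tensors, safetensors_file_path, offload_device="cpu"):
        # As in the diff: map each tensor object to the key it is saved under.
        self.tensor_to_key = {tensor: key for key, tensor in tensors.items()}
        self.safetensors_file_path = safetensors_file_path
        self.offload_device = torch.device(offload_device)
        self._is_offloaded_to_disk = False

    def offload_(self):
        # Write only if the file was neither saved earlier in this session nor
        # left behind by a previous run; otherwise reuse the existing file.
        if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
            os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
            tensors_to_save = {
                key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()
            }
            safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)

        # Either way, the group now counts as offloaded for the rest of the session.
        self._is_offloaded_to_disk = True

        # Drop the in-memory copies; the data now lives in the safetensors file.
        for tensor_obj in self.tensor_to_key.keys():
            tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)


# Hypothetical usage: repeated calls in one session hit the disk at most once, and a
# rerun that finds the file from an earlier session skips the write entirely.
group = DiskOffloadSketch({"weight": torch.randn(4, 4)}, "offload_dir/group_0.safetensors")
group.offload_()  # writes offload_dir/group_0.safetensors unless it already exists
group.offload_()  # no disk write: the session flag short-circuits the check

On the onload side, the counterpart path reads the tensors back from the file (e.g. via `safetensors.torch.load_file`) rather than from the emptied in-memory copies, which is what makes discarding them here safe.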