mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
update
This commit is contained in:
@@ -81,6 +81,7 @@ class ModuleGroup:
|
||||
self._is_offloaded_to_disk = False
|
||||
|
||||
if self.offload_to_disk_path:
|
||||
self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}.safetensors")
|
||||
all_tensors = []
|
||||
param_names = []
|
||||
for module in self.modules:
|
||||
@@ -96,9 +97,6 @@ class ModuleGroup:
|
||||
self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)}
|
||||
self.key_to_tensor = {v: k for k, v in self.tensor_to_key.items()}
|
||||
|
||||
group_id_key = "_".join(sorted(param_names))
|
||||
self._disk_offload_group_id = hashlib.md5(group_id_key.encode()).hexdigest()[:8]
|
||||
|
||||
self.cpu_param_dict = {}
|
||||
else:
|
||||
self.cpu_param_dict = self._init_cpu_param_dict()
|
||||
@@ -106,13 +104,6 @@ class ModuleGroup:
|
||||
if self.stream is None and self.record_stream:
|
||||
raise ValueError("`record_stream` cannot be True when `stream` is None.")
|
||||
|
||||
@property
|
||||
def _disk_offload_file_path(self):
|
||||
if self.offload_to_disk_path:
|
||||
return os.path.join(self.offload_to_disk_path, f"group_{self._disk_offload_group_id}.safetensors")
|
||||
|
||||
return None
|
||||
|
||||
def _init_cpu_param_dict(self):
|
||||
cpu_param_dict = {}
|
||||
if self.stream is None:
|
||||
@@ -173,7 +164,7 @@ class ModuleGroup:
|
||||
|
||||
def _onload_from_disk(self, current_stream):
|
||||
if self.stream is not None:
|
||||
loaded_cpu_tensors = safetensors.torch.load_file(self._disk_offload_file_path, device="cpu")
|
||||
loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
|
||||
|
||||
for key, tensor_obj in self.key_to_tensor.items():
|
||||
self.cpu_param_dict[tensor_obj] = loaded_cpu_tensors[key]
|
||||
@@ -188,7 +179,7 @@ class ModuleGroup:
|
||||
onload_device = (
|
||||
self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
|
||||
)
|
||||
loaded_tensors = safetensors.torch.load_file(self._disk_offload_file_path, device=onload_device)
|
||||
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
|
||||
for key, tensor_obj in self.key_to_tensor.items():
|
||||
tensor_obj.data = loaded_tensors[key]
|
||||
|
||||
@@ -231,10 +222,10 @@ class ModuleGroup:
|
||||
# overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
|
||||
# we perform a write.
|
||||
# Check if the file has been saved in this session or if it already exists on disk.
|
||||
if not self._is_offloaded_to_disk and not os.path.exists(self._disk_offload_file_path):
|
||||
os.makedirs(os.path.dirname(self._disk_offload_file_path), exist_ok=True)
|
||||
if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
|
||||
os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
|
||||
tensors_to_save = {key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()}
|
||||
safetensors.torch.save_file(tensors_to_save, self._disk_offload_file_path)
|
||||
safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
|
||||
|
||||
# The group is now considered offloaded to disk for the rest of the session.
|
||||
self._is_offloaded_to_disk = True
|
||||
|
||||
Reference in New Issue
Block a user