mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
DN6
2025-06-19 18:26:10 +05:30
parent 90e546ada1
commit ace698aa96
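
The hunks below refactor ModuleGroup's onload_/offload_ paths in diffusers' group offloading hook. For orientation, here is a minimal usage sketch (not part of this commit) that exercises this code path; it assumes the public apply_group_offloading helper and the option names suggested by the attributes used in the diff (use_stream, record_stream, offload_to_disk_path):

import torch
from diffusers.hooks import apply_group_offloading

# Illustrative only -- a stand-in module; any diffusers model is handled the same way.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))

apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,                     # overlap Host->Device copies with compute (self.stream below)
    record_stream=True,                  # see _transfer_tensor_to_device below
    offload_to_disk_path="offload_dir",  # spill each group to a .safetensors file (self.safetensors_file_path)
)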


@@ -135,9 +135,32 @@ class ModuleGroup:
         finally:
             pinned_dict = None
 
+    def _transfer_tensor_to_device(self, tensor, source_tensor=None, current_stream=None):
+        if source_tensor is None:
+            source_tensor = tensor
+        tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
+        if self.record_stream and current_stream is not None:
+            tensor.data.record_stream(current_stream)
+
+    def _process_tensors_from_modules(self, pinned_memory=None, current_stream=None):
+        for group_module in self.modules:
+            for param in group_module.parameters():
+                source = pinned_memory[param] if pinned_memory else param.data
+                self._transfer_tensor_to_device(param, source, current_stream)
+            for buffer in group_module.buffers():
+                source = pinned_memory[buffer] if pinned_memory else buffer.data
+                self._transfer_tensor_to_device(buffer, source, current_stream)
+
+        for param in self.parameters:
+            source = pinned_memory[param] if pinned_memory else param.data
+            self._transfer_tensor_to_device(param, source, current_stream)
+
+        for buffer in self.buffers:
+            source = pinned_memory[buffer] if pinned_memory else buffer.data
+            self._transfer_tensor_to_device(buffer, source, current_stream)
+
     @torch.compiler.disable()
     def onload_(self):
         r"""Onloads the group of modules to the onload_device."""
         torch_accelerator_module = (
             getattr(torch, torch.accelerator.current_accelerator().type)
             if hasattr(torch, "accelerator")
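
For reference (outside the diff): the transfer pattern that the new _transfer_tensor_to_device helper wraps is plain PyTorch -- stage the CPU copy in pinned memory, issue a non-blocking Host->Device copy on a dedicated stream, and record_stream the result so the caching allocator does not recycle it while the compute stream still uses it. A minimal, self-contained sketch with made-up tensor and stream names:

import torch

if torch.cuda.is_available():
    transfer_stream = torch.cuda.Stream()
    param_cpu = torch.randn(1024, 1024)      # stand-in for a parameter staged on CPU
    pinned = param_cpu.pin_memory()          # page-locked memory enables a truly async copy

    with torch.cuda.stream(transfer_stream):
        param_gpu = pinned.to("cuda", non_blocking=True)  # H2D copy enqueued on transfer_stream

    # Mark the tensor as used on the compute (current) stream so its memory is not
    # reused before that stream has consumed it -- the role of record_stream above.
    param_gpu.record_stream(torch.cuda.current_stream())

    torch.cuda.current_stream().wait_stream(transfer_stream)  # order compute after the copy
    result = param_gpu.sum()                                  # safe to use the onloaded tensor now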
@@ -146,96 +169,65 @@ class ModuleGroup:
         context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
         current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
 
-        if self.offload_to_disk_path:
-            if self.stream is not None:
-                # Wait for previous Host->Device transfer to complete
-                self.stream.synchronize()
-
-            with context:
-                if self.stream is not None:
-                    # Load to CPU, pin, and async copy to device for overlapping transfer and compute
-                    loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
-                    for key, tensor_obj in self.key_to_tensor.items():
-                        pinned_tensor = loaded_cpu_tensors[key].pin_memory()
-                        tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
-                        if self.record_stream:
-                            tensor_obj.data.record_stream(current_stream)
-                else:
-                    # Load directly to the target device (synchronous)
-                    onload_device = (
-                        self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
-                    )
-                    loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
-                    for key, tensor_obj in self.key_to_tensor.items():
-                        tensor_obj.data = loaded_tensors[key]
-            return
-
         if self.stream is not None:
             # Wait for previous Host->Device transfer to complete
             self.stream.synchronize()
 
         with context:
-            if self.stream is not None:
-                with self._pinned_memory_tensors() as pinned_memory:
-                    for group_module in self.modules:
-                        for param in group_module.parameters():
-                            param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
-                            if self.record_stream:
-                                param.data.record_stream(current_stream)
-                        for buffer in group_module.buffers():
-                            buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
-                            if self.record_stream:
-                                buffer.data.record_stream(current_stream)
-                    for param in self.parameters:
-                        param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
-                        if self.record_stream:
-                            param.data.record_stream(current_stream)
-                    for buffer in self.buffers:
-                        buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
-                        if self.record_stream:
-                            buffer.data.record_stream(current_stream)
+            if self.offload_to_disk_path:
+                self._onload_from_disk(current_stream)
             else:
-                for group_module in self.modules:
-                    for param in group_module.parameters():
-                        param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
-                    for buffer in group_module.buffers():
-                        buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
+                self._onload_from_memory(current_stream)
-                for param in self.parameters:
-                    param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
+    def _onload_from_disk(self, current_stream):
+        if self.stream is not None:
+            loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
-                for buffer in self.buffers:
-                    buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
-                    if self.record_stream:
-                        buffer.data.record_stream(current_stream)
+            for key, tensor_obj in self.key_to_tensor.items():
+                self.cpu_param_dict[tensor_obj] = loaded_cpu_tensors[key]
+
+            with self._pinned_memory_tensors() as pinned_memory:
+                for key, tensor_obj in self.key_to_tensor.items():
+                    self._transfer_tensor_to_device(tensor_obj, pinned_memory[tensor_obj], current_stream)
+
+            self.cpu_param_dict.clear()
+        else:
+            onload_device = (
+                self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
+            )
+            loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
+            for key, tensor_obj in self.key_to_tensor.items():
+                tensor_obj.data = loaded_tensors[key]
+
+    def _onload_from_memory(self, current_stream):
+        if self.stream is not None:
+            with self._pinned_memory_tensors() as pinned_memory:
+                self._process_tensors_from_modules(pinned_memory, current_stream)
+        else:
+            self._process_tensors_from_modules(None, current_stream)
-    @torch.compiler.disable()
-    def offload_(self):
-        r"""Offloads the group of modules to the offload_device."""
-        if self.offload_to_disk_path:
-            # TODO: we can potentially optimize this code path by checking if the _all_ the desired
-            # safetensor files exist on the disk and if so, skip this step entirely, reducing IO
-            # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
-            # we perform a write.
-            # Check if the file has been saved in this session or if it already exists on disk.
-            if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
-                os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
-                tensors_to_save = {
-                    key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()
-                }
-                safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
+    def _offload_to_disk(self):
+        # TODO: we can potentially optimize this code path by checking if the _all_ the desired
+        # safetensor files exist on the disk and if so, skip this step entirely, reducing IO
+        # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
+        # we perform a write.
+        # Check if the file has been saved in this session or if it already exists on disk.
+        if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
+            os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
+            tensors_to_save = {key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()}
+            safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
+
+        # The group is now considered offloaded to disk for the rest of the session.
+        self._is_offloaded_to_disk = True
-            # The group is now considered offloaded to disk for the rest of the session.
-            self._is_offloaded_to_disk = True
-
-            # We do this to free up the RAM which is still holding the up tensor data.
-            for tensor_obj in self.tensor_to_key.keys():
-                tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
-            return
+
+        # We do this to free up the RAM which is still holding the up tensor data.
+        for tensor_obj in self.tensor_to_key.keys():
+            tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
+
+    @torch.compiler.disable()
+    def _offload_to_memory(self):
         torch_accelerator_module = (
             getattr(torch, torch.accelerator.current_accelerator().type)
             if hasattr(torch, "accelerator")
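
For reference (outside the diff): the disk round trip implemented by _offload_to_disk / _onload_from_disk boils down to safetensors' save_file/load_file plus the pinned-memory copy shown earlier. A minimal sketch with placeholder file and key names:

import os
import torch
import safetensors.torch

offload_path = os.path.join("offload_dir", "group_0.safetensors")  # placeholder path
os.makedirs(os.path.dirname(offload_path), exist_ok=True)
tensors = {"weight": torch.randn(256, 256), "bias": torch.randn(256)}

# Offload: write the group once, then drop the RAM copies (mirrors _offload_to_disk).
if not os.path.exists(offload_path):
    safetensors.torch.save_file(tensors, offload_path)
tensors = {key: torch.empty(0) for key in tensors}

# Onload: read back to CPU and, if a GPU is present, stage through pinned memory
# for a non-blocking copy (mirrors the streamed branch of _onload_from_disk).
loaded = safetensors.torch.load_file(offload_path, device="cpu")
if torch.cuda.is_available():
    restored = {key: value.pin_memory().to("cuda", non_blocking=True) for key, value in loaded.items()}
else:
    restored = loaded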
@@ -260,6 +252,14 @@ class ModuleGroup:
             for buffer in self.buffers:
                 buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
 
+    @torch.compiler.disable()
+    def offload_(self):
+        r"""Offloads the group of modules to the offload_device."""
+        if self.offload_to_disk_path:
+            self._offload_to_disk()
+        else:
+            self._offload_to_memory()
+
 
 class GroupOffloadingHook(ModelHook):
     r"""