mirror of https://github.com/huggingface/diffusers.git
update
@@ -135,9 +135,7 @@ class ModuleGroup:
         finally:
             pinned_dict = None
 
-    def _transfer_tensor_to_device(self, tensor, source_tensor=None, current_stream=None):
-        if source_tensor is None:
-            source_tensor = tensor
+    def _transfer_tensor_to_device(self, tensor, source_tensor, current_stream=None):
         tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
         if self.record_stream and current_stream is not None:
             tensor.data.record_stream(current_stream)
@@ -159,26 +157,6 @@ class ModuleGroup:
             source = pinned_memory[buffer] if pinned_memory else buffer.data
             self._transfer_tensor_to_device(buffer, source, current_stream)
 
-    @torch.compiler.disable()
-    def onload_(self):
-        torch_accelerator_module = (
-            getattr(torch, torch.accelerator.current_accelerator().type)
-            if hasattr(torch, "accelerator")
-            else torch.cuda
-        )
-        context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
-        current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
-
-        if self.stream is not None:
-            # Wait for previous Host->Device transfer to complete
-            self.stream.synchronize()
-
-        with context:
-            if self.offload_to_disk_path:
-                self._onload_from_disk(current_stream)
-            else:
-                self._onload_from_memory(current_stream)
-
     def _onload_from_disk(self, current_stream):
         if self.stream is not None:
             loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
@@ -207,6 +185,26 @@ class ModuleGroup:
         else:
             self._process_tensors_from_modules(None, current_stream)
 
+    @torch.compiler.disable()
+    def onload_(self):
+        torch_accelerator_module = (
+            getattr(torch, torch.accelerator.current_accelerator().type)
+            if hasattr(torch, "accelerator")
+            else torch.cuda
+        )
+        context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
+        current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
+
+        if self.stream is not None:
+            # Wait for previous Host->Device transfer to complete
+            self.stream.synchronize()
+
+        with context:
+            if self.offload_to_disk_path:
+                self._onload_from_disk(current_stream)
+            else:
+                self._onload_from_memory(current_stream)
+
     @torch.compiler.disable()
     def _offload_to_disk(self):
         # TODO: we can potentially optimize this code path by checking if the _all_ the desired
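Note on the first hunk: `source_tensor` becomes a required argument. As the context lines in the second hunk show, the caller already resolves the copy source itself (a pinned staging buffer when one exists, otherwise the tensor's own storage), so the old `source_tensor = tensor` fallback inside the method was dead code. Why the pinned source matters, as a minimal standalone sketch (tensor names are invented; this is not the ModuleGroup API):

import torch

# non_blocking=True only overlaps the host->device copy with compute when the
# source lives in pinned (page-locked) host memory; from ordinary pageable
# memory the transfer effectively becomes synchronous.
cpu_param = torch.randn(1024, 1024)   # pageable host tensor
pinned_src = cpu_param.pin_memory()   # page-locked staging copy

if torch.cuda.is_available():
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        gpu_param = pinned_src.to("cuda", non_blocking=True)
    stream.synchronize()              # ensure the async copy has finished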
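Note on the second and third hunks: `onload_` is deleted and re-added verbatim, i.e. a pure move below `_process_tensors_from_modules` and `_onload_from_disk`. The method itself is a standard side-stream prefetch: run the copies under a transfer stream, and with `record_stream` enabled, mark each tensor as in use by the current compute stream so the caching allocator does not recycle its memory early. A self-contained sketch of that pattern under assumed names (`prefetch_to_device` is hypothetical, not a diffusers API):

from contextlib import nullcontext

import torch


def prefetch_to_device(tensors, device, stream=None, record_stream=False):
    # Resolve the accelerator namespace the way the diff does: prefer the
    # generic torch.accelerator API when present, else assume CUDA.
    accel = (
        getattr(torch, torch.accelerator.current_accelerator().type)
        if hasattr(torch, "accelerator")
        else torch.cuda
    )
    ctx = nullcontext() if stream is None else accel.stream(stream)
    current = accel.current_stream() if record_stream else None

    if stream is not None:
        stream.synchronize()  # wait for the previous host->device transfer

    with ctx:
        for t in tensors:
            t.data = t.data.to(device, non_blocking=stream is not None)
            if current is not None:
                # Tell the allocator that `current` will read this memory.
                t.data.record_stream(current)

With stream=None the context is a no-op and the copies are plain blocking transfers, which matches the `self.stream is None` branch in the diff.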
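Note on the disk path: the context around `_onload_from_disk` shows an offloaded group being reloaded with safetensors onto CPU before the device transfer. A hedged round-trip sketch (the file name is invented):

import safetensors.torch
import torch

# Persist a tensor dict the way a disk-offloaded group might be saved, then
# load it back to CPU; the stream-ordered device copy happens afterwards.
state = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}
safetensors.torch.save_file(state, "group.safetensors")
loaded = safetensors.torch.load_file("group.safetensors", device="cpu")
assert torch.equal(state["weight"], loaded["weight"])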