diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index ac6627ebb8..33a047cf0c 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -135,9 +135,7 @@ class ModuleGroup: finally: pinned_dict = None - def _transfer_tensor_to_device(self, tensor, source_tensor=None, current_stream=None): - if source_tensor is None: - source_tensor = tensor + def _transfer_tensor_to_device(self, tensor, source_tensor, current_stream=None): tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking) if self.record_stream and current_stream is not None: tensor.data.record_stream(current_stream) @@ -159,26 +157,6 @@ class ModuleGroup: source = pinned_memory[buffer] if pinned_memory else buffer.data self._transfer_tensor_to_device(buffer, source, current_stream) - @torch.compiler.disable() - def onload_(self): - torch_accelerator_module = ( - getattr(torch, torch.accelerator.current_accelerator().type) - if hasattr(torch, "accelerator") - else torch.cuda - ) - context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream) - current_stream = torch_accelerator_module.current_stream() if self.record_stream else None - - if self.stream is not None: - # Wait for previous Host->Device transfer to complete - self.stream.synchronize() - - with context: - if self.offload_to_disk_path: - self._onload_from_disk(current_stream) - else: - self._onload_from_memory(current_stream) - def _onload_from_disk(self, current_stream): if self.stream is not None: loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu") @@ -207,6 +185,26 @@ class ModuleGroup: else: self._process_tensors_from_modules(None, current_stream) + @torch.compiler.disable() + def onload_(self): + torch_accelerator_module = ( + getattr(torch, torch.accelerator.current_accelerator().type) + if hasattr(torch, "accelerator") + else torch.cuda + ) + context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream) + current_stream = torch_accelerator_module.current_stream() if self.record_stream else None + + if self.stream is not None: + # Wait for previous Host->Device transfer to complete + self.stream.synchronize() + + with context: + if self.offload_to_disk_path: + self._onload_from_disk(current_stream) + else: + self._onload_from_memory(current_stream) + @torch.compiler.disable() def _offload_to_disk(self): # TODO: we can potentially optimize this code path by checking if the _all_ the desired