From ff690a132426152918d3b7eb5814a48ab773294a Mon Sep 17 00:00:00 2001
From: DN6
Date: Fri, 12 Dec 2025 11:01:36 +0530
Subject: [PATCH] update

---
 src/diffusers/hooks/group_offloading.py | 53 ++++++++++++++-----------
 1 file changed, 29 insertions(+), 24 deletions(-)

diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py
index 47f1f41996..26144470bb 100644
--- a/src/diffusers/hooks/group_offloading.py
+++ b/src/diffusers/hooks/group_offloading.py
@@ -156,38 +156,33 @@ class ModuleGroup:
         finally:
             pinned_dict = None
 
-    def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
+    def _transfer_tensor_to_device(self, tensor, source_tensor):
         tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
-        if self.record_stream:
-            tensor.data.record_stream(default_stream)
 
-    def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None):
+    def _process_tensors_from_modules(self, pinned_memory=None):
         for group_module in self.modules:
             for param in group_module.parameters():
                 source = pinned_memory[param] if pinned_memory else param.data
-                self._transfer_tensor_to_device(param, source, default_stream)
+                self._transfer_tensor_to_device(param, source)
             for buffer in group_module.buffers():
                 source = pinned_memory[buffer] if pinned_memory else buffer.data
-                self._transfer_tensor_to_device(buffer, source, default_stream)
+                self._transfer_tensor_to_device(buffer, source)
 
         for param in self.parameters:
             source = pinned_memory[param] if pinned_memory else param.data
-            self._transfer_tensor_to_device(param, source, default_stream)
+            self._transfer_tensor_to_device(param, source)
 
         for buffer in self.buffers:
             source = pinned_memory[buffer] if pinned_memory else buffer.data
-            self._transfer_tensor_to_device(buffer, source, default_stream)
+            self._transfer_tensor_to_device(buffer, source)
 
     def _onload_from_disk(self):
         if self.stream is not None:
-            # Wait for previous Host->Device transfer to complete
             self.stream.synchronize()
 
         context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
-        current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None
 
         with context:
-            # Load to CPU (if using streams) or directly to target device, pin, and async copy to device
             device = str(self.onload_device) if self.stream is None else "cpu"
             loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device)
@@ -195,8 +190,6 @@ class ModuleGroup:
                 for key, tensor_obj in self.key_to_tensor.items():
                     pinned_tensor = loaded_tensors[key].pin_memory()
                     tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
-                    if self.record_stream:
-                        tensor_obj.data.record_stream(current_stream)
             else:
                 onload_device = (
                     self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
                 )
@@ -207,45 +200,57 @@ class ModuleGroup:
 
     def _onload_from_memory(self):
         if self.stream is not None:
-            # Wait for previous Host->Device transfer to complete
             self.stream.synchronize()
 
         context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
-        default_stream = self._torch_accelerator_module.current_stream() if self.stream is not None else None
 
         with context:
             if self.stream is not None:
                 with self._pinned_memory_tensors() as pinned_memory:
-                    self._process_tensors_from_modules(pinned_memory, default_stream=default_stream)
+                    self._process_tensors_from_modules(pinned_memory)
             else:
                 self._process_tensors_from_modules(None)
 
     def _offload_to_disk(self):
-        # TODO: we can potentially optimize this code path by checking if the _all_ the desired
-        # safetensor files exist on the disk and if so, skip this step entirely, reducing IO
-        # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
-        # we perform a write.
-        # Check if the file has been saved in this session or if it already exists on disk.
         if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
             os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
             tensors_to_save = {key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()}
             safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
 
-        # The group is now considered offloaded to disk for the rest of the session.
         self._is_offloaded_to_disk = True
 
-        # We do this to free up the RAM which is still holding the up tensor data.
+        if self.stream is not None:
+            if self.record_stream:
+                current_stream = self._torch_accelerator_module.current_stream()
+                for tensor_obj in self.tensor_to_key.keys():
+                    tensor_obj.data.record_stream(current_stream)
+            else:
+                self._torch_accelerator_module.current_stream().synchronize()
+
         for tensor_obj in self.tensor_to_key.keys():
             tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
 
     def _offload_to_memory(self):
         if self.stream is not None:
-            if not self.record_stream:
+            if self.record_stream:
+                current_stream = self._torch_accelerator_module.current_stream()
+                for group_module in self.modules:
+                    for param in group_module.parameters():
+                        param.data.record_stream(current_stream)
+                    for buffer in group_module.buffers():
+                        buffer.data.record_stream(current_stream)
+                for param in self.parameters:
+                    param.data.record_stream(current_stream)
+                for buffer in self.buffers:
+                    buffer.data.record_stream(current_stream)
+            else:
                 self._torch_accelerator_module.current_stream().synchronize()
 
             for group_module in self.modules:
                 for param in group_module.parameters():
                     param.data = self.cpu_param_dict[param]
+                for buffer in group_module.buffers():
+                    buffer.data = self.cpu_param_dict[buffer]
             for param in self.parameters:
                 param.data = self.cpu_param_dict[param]
             for buffer in self.buffers:
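
Note (review commentary, not part of the patch): the change moves all `record_stream` bookkeeping out of the onload helpers and into the offload paths, right before each GPU tensor's storage is released. The rule it relies on is standard PyTorch stream semantics: a tensor allocated under one stream but consumed on another must either call `Tensor.record_stream` for the consumer stream, or the host must synchronize that stream, before the memory is freed. A minimal sketch of the two options follows, using only public torch.cuda APIs; it assumes a CUDA device, and the tensor names and sizes are invented for illustration:

    import torch

    assert torch.cuda.is_available()
    transfer_stream = torch.cuda.Stream()
    compute_stream = torch.cuda.current_stream()

    # Onload: async host->device copy on a side stream (pinned memory makes
    # the copy truly asynchronous), then make compute wait on the copy.
    weight_cpu = torch.randn(4096, 4096).pin_memory()
    with torch.cuda.stream(transfer_stream):
        weight_gpu = weight_cpu.to("cuda", non_blocking=True)
    compute_stream.wait_stream(transfer_stream)

    out = weight_gpu @ weight_gpu.T  # queued on the compute stream

    # Offload: weight_gpu was allocated under transfer_stream, so the caching
    # allocator only fences its deallocation against that stream. Before
    # dropping the GPU copy, make the pending matmul safe in one of two ways:
    use_record_stream = True
    if use_record_stream:
        # record_stream=True path: mark the tensor as used on the compute
        # stream; the allocator defers reusing its memory until that stream's
        # pending work finishes. No host-side stall.
        weight_gpu.record_stream(compute_stream)
    else:
        # record_stream=False path: block the host until the compute stream
        # drains, then free. Simpler, but serializes host and device.
        compute_stream.synchronize()
    del weight_gpu  # storage can now be reclaimed without racing the matmul

This is also why the offload paths now cover buffers as well as parameters: a buffer copied to the device on the transfer stream has the same deallocation hazard as a parameter, so it needs the same `record_stream`/synchronize treatment before its `data` is rebound.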