mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-29 07:22:12 +03:00
make autoencoders, controlnet_flux and wan_transformer3d_single_file pass on xpu (#11461)
* make autoencoders, controlnet_flux and wan_transformer3d_single_file pass on XPU

  Signed-off-by: Yao Matrix <matrix.yao@intel.com>

* Apply style fixes

---------

Signed-off-by: Yao Matrix <matrix.yao@intel.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Aryan <aryan@huggingface.co>
@@ -13,7 +13,7 @@
 # limitations under the License.

 from contextlib import contextmanager, nullcontext
-from typing import Dict, List, Optional, Set, Tuple
+from typing import Dict, List, Optional, Set, Tuple, Union

 import torch

@@ -55,7 +55,7 @@ class ModuleGroup:
         parameters: Optional[List[torch.nn.Parameter]] = None,
         buffers: Optional[List[torch.Tensor]] = None,
         non_blocking: bool = False,
-        stream: Optional[torch.cuda.Stream] = None,
+        stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
         record_stream: Optional[bool] = False,
         low_cpu_mem_usage: bool = False,
         onload_self: bool = True,
@@ -115,8 +115,13 @@ class ModuleGroup:

     def onload_(self):
         r"""Onloads the group of modules to the onload_device."""
-        context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
-        current_stream = torch.cuda.current_stream() if self.record_stream else None
+        torch_accelerator_module = (
+            getattr(torch, torch.accelerator.current_accelerator().type)
+            if hasattr(torch, "accelerator")
+            else torch.cuda
+        )
+        context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
+        current_stream = torch_accelerator_module.current_stream() if self.record_stream else None

         if self.stream is not None:
             # Wait for previous Host->Device transfer to complete
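The hunk above (and the offload_ hunk below) replaces hard-coded torch.cuda calls with a backend namespace resolved at runtime, so the same code drives both CUDA and Intel XPU. A minimal standalone sketch of that resolution pattern, assuming a PyTorch build that exposes torch.accelerator and a machine with at least one accelerator; the variable names mirror the diff:

    import torch
    from contextlib import nullcontext

    # Resolve the backend namespace (torch.cuda, torch.xpu, ...) once; builds
    # without torch.accelerator keep the previous CUDA-only behavior.
    torch_accelerator_module = (
        getattr(torch, torch.accelerator.current_accelerator().type)
        if hasattr(torch, "accelerator")
        else torch.cuda
    )

    # The resolved namespace then provides the stream helpers used by onload_/offload_.
    stream = None  # e.g. torch.cuda.Stream() on CUDA, torch.Stream() on XPU
    context = nullcontext() if stream is None else torch_accelerator_module.stream(stream)
    current_stream = torch_accelerator_module.current_stream()
    current_stream.synchronize()
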
@@ -162,9 +167,15 @@ class ModuleGroup:

     def offload_(self):
         r"""Offloads the group of modules to the offload_device."""
+
+        torch_accelerator_module = (
+            getattr(torch, torch.accelerator.current_accelerator().type)
+            if hasattr(torch, "accelerator")
+            else torch.cuda
+        )
         if self.stream is not None:
             if not self.record_stream:
-                torch.cuda.current_stream().synchronize()
+                torch_accelerator_module.current_stream().synchronize()
             for group_module in self.modules:
                 for param in group_module.parameters():
                     param.data = self.cpu_param_dict[param]
@@ -429,8 +440,10 @@ def apply_group_offloading(
     if use_stream:
         if torch.cuda.is_available():
            stream = torch.cuda.Stream()
+        elif hasattr(torch, "xpu") and torch.xpu.is_available():
+            stream = torch.Stream()
         else:
-            raise ValueError("Using streams for data transfer requires a CUDA device.")
+            raise ValueError("Using streams for data transfer requires a CUDA device, or an Intel XPU device.")

     _raise_error_if_accelerate_model_or_sequential_hook_present(module)

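With the stream-selection change above, stream-assisted group offloading can be requested on an XPU machine the same way as on CUDA. A hedged usage sketch: apply_group_offloading, use_stream, onload_device, and offload_device appear in this diff, while the import path, the offload_type and num_blocks_per_group parameters, and the example checkpoint follow the diffusers group-offloading docs and should be checked against the installed version:

    import torch
    from diffusers import AutoencoderKL
    from diffusers.hooks import apply_group_offloading

    # Use whichever accelerator is present (mirrors the check in the diff).
    if torch.cuda.is_available():
        onload_device = torch.device("cuda")
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        onload_device = torch.device("xpu")
    else:
        raise RuntimeError("This sketch expects a CUDA or XPU device.")

    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
    apply_group_offloading(
        vae,
        onload_device=onload_device,
        offload_device=torch.device("cpu"),
        offload_type="block_level",
        num_blocks_per_group=1,
        use_stream=True,  # after this commit, creates torch.Stream() on XPU instead of raising
    )
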
@@ -468,7 +481,7 @@ def _apply_group_offloading_block_level(
     offload_device: torch.device,
     onload_device: torch.device,
     non_blocking: bool,
-    stream: Optional[torch.cuda.Stream] = None,
+    stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
     record_stream: Optional[bool] = False,
     low_cpu_mem_usage: bool = False,
 ) -> None:
@@ -486,7 +499,7 @@ def _apply_group_offloading_block_level(
         non_blocking (`bool`):
             If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
             and data transfer.
-        stream (`torch.cuda.Stream`, *optional*):
+        stream (`torch.cuda.Stream` or `torch.Stream`, *optional*):
             If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
             for overlapping computation and data transfer.
         record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
@@ -572,7 +585,7 @@ def _apply_group_offloading_leaf_level(
     offload_device: torch.device,
     onload_device: torch.device,
     non_blocking: bool,
-    stream: Optional[torch.cuda.Stream] = None,
+    stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
     record_stream: Optional[bool] = False,
     low_cpu_mem_usage: bool = False,
 ) -> None:
@@ -592,7 +605,7 @@ def _apply_group_offloading_leaf_level(
         non_blocking (`bool`):
             If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
             and data transfer.
-        stream (`torch.cuda.Stream`, *optional*):
+        stream (`torch.cuda.Stream` or `torch.Stream`, *optional*):
             If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
             for overlapping computation and data transfer.
         record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor