1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

xpu enabling for 4 cases (#12345)

Signed-off-by: Yao, Matrix <matrix.yao@intel.com>
This commit is contained in:
Yao Matrix
2025-09-23 21:01:45 -07:00
committed by GitHub
parent 9ef118509e
commit 7a58734994
2 changed files with 11 additions and 3 deletions

View File

@@ -25,6 +25,7 @@ from ..utils import (
is_accelerate_available,
logging,
)
from ..utils.torch_utils import get_device
if is_accelerate_available():
@@ -161,7 +162,9 @@ class AutoOffloadStrategy:
current_module_size = model.get_memory_footprint()
mem_on_device = torch.cuda.mem_get_info(execution_device.index)[0]
device_type = execution_device.type
device_module = getattr(torch, device_type, torch.cuda)
mem_on_device = device_module.mem_get_info(execution_device.index)[0]
mem_on_device = mem_on_device - self.memory_reserve_margin
if current_module_size < mem_on_device:
return []
@@ -301,7 +304,7 @@ class ComponentsManager:
cm.add("vae", vae_model, collection="sdxl")
# Enable auto offloading
cm.enable_auto_cpu_offload(device="cuda")
cm.enable_auto_cpu_offload()
# Retrieve components
unet = cm.get_one(name="unet", collection="sdxl")
@@ -490,6 +493,8 @@ class ComponentsManager:
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
if torch.xpu.is_available():
torch.xpu.empty_cache()
# YiYi TODO: rename to search_components for now, may remove this method
def search_components(
@@ -678,7 +683,7 @@ class ComponentsManager:
return get_return_dict(matches, return_dict_with_names)
def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = "cuda", memory_reserve_margin="3GB"):
def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = None, memory_reserve_margin="3GB"):
"""
Enable automatic CPU offloading for all components.
@@ -704,6 +709,8 @@ class ComponentsManager:
self.disable_auto_cpu_offload()
offload_strategy = AutoOffloadStrategy(memory_reserve_margin=memory_reserve_margin)
if device is None:
device = get_device()
device = torch.device(device)
if device.index is None:
device = torch.device(f"{device.type}:{0}")

View File

@@ -253,6 +253,7 @@ class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
expected_slices = Expectations(
{
("cuda", 7): np.array([0.1013, 0.1924, 0.0078, 0.1021, 0.1929, 0.0078, 0.1023, 0.1919, 0.7402, 0.104, 0.4482, 0.7354, 0.0925, 0.4382, 0.7275, 0.0815]),
("xpu", 3): np.array([0.1013, 0.1924, 0.0078, 0.1021, 0.1929, 0.0078, 0.1023, 0.1919, 0.7402, 0.104, 0.4482, 0.7354, 0.0925, 0.4382, 0.7275, 0.0815]),
}
)
# fmt: on