1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

enable group_offloading and PipelineDeviceAndDtypeStabilityTests on XPU, all passed (#11620)

* enable group_offloading and PipelineDeviceAndDtypeStabilityTests on XPU,
all passed

Signed-off-by: Matrix YAO <matrix.yao@intel.com>

* fix style

Signed-off-by: Matrix YAO <matrix.yao@intel.com>

* fix

Signed-off-by: Matrix YAO <matrix.yao@intel.com>

---------

Signed-off-by: Matrix YAO <matrix.yao@intel.com>
Co-authored-by: Aryan <aryan@huggingface.co>
This commit is contained in:
Yao Matrix
2025-05-30 14:00:37 +08:00
committed by GitHub
parent 3651bdb766
commit a7aa8bf28a
2 changed files with 31 additions and 24 deletions

View File

@@ -22,7 +22,13 @@ from diffusers.models import ModelMixin
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.utils import get_logger
from diffusers.utils.import_utils import compare_versions
from diffusers.utils.testing_utils import require_torch_gpu, torch_device
from diffusers.utils.testing_utils import (
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_peak_memory_stats,
require_torch_accelerator,
torch_device,
)
class DummyBlock(torch.nn.Module):
@@ -107,7 +113,7 @@ class DummyPipeline(DiffusionPipeline):
return x
@require_torch_gpu
@require_torch_accelerator
class GroupOffloadTests(unittest.TestCase):
in_features = 64
hidden_features = 256
@@ -125,8 +131,8 @@ class GroupOffloadTests(unittest.TestCase):
del self.model
del self.input
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
backend_empty_cache(torch_device)
backend_reset_peak_memory_stats(torch_device)
def get_model(self):
torch.manual_seed(0)
@@ -141,8 +147,8 @@ class GroupOffloadTests(unittest.TestCase):
@torch.no_grad()
def run_forward(model):
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
backend_empty_cache(torch_device)
backend_reset_peak_memory_stats(torch_device)
self.assertTrue(
all(
module._diffusers_hook.get_hook("group_offloading") is not None
@@ -152,7 +158,7 @@ class GroupOffloadTests(unittest.TestCase):
)
model.eval()
output = model(self.input)[0].cpu()
max_memory_allocated = torch.cuda.max_memory_allocated()
max_memory_allocated = backend_max_memory_allocated(torch_device)
return output, max_memory_allocated
self.model.to(torch_device)
@@ -187,10 +193,10 @@ class GroupOffloadTests(unittest.TestCase):
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading5, atol=1e-5))
# Memory assertions - offloading should reduce memory usage
self.assertTrue(mem4 <= mem5 < mem2 < mem3 < mem1 < mem_baseline)
self.assertTrue(mem4 <= mem5 < mem2 <= mem3 < mem1 < mem_baseline)
def test_warning_logged_if_group_offloaded_module_moved_to_cuda(self):
if torch.device(torch_device).type != "cuda":
def test_warning_logged_if_group_offloaded_module_moved_to_accelerator(self):
if torch.device(torch_device).type not in ["cuda", "xpu"]:
return
self.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
logger = get_logger("diffusers.models.modeling_utils")
@@ -199,8 +205,8 @@ class GroupOffloadTests(unittest.TestCase):
self.model.to(torch_device)
self.assertIn(f"The module '{self.model.__class__.__name__}' is group offloaded", cm.output[0])
def test_warning_logged_if_group_offloaded_pipe_moved_to_cuda(self):
if torch.device(torch_device).type != "cuda":
def test_warning_logged_if_group_offloaded_pipe_moved_to_accelerator(self):
if torch.device(torch_device).type not in ["cuda", "xpu"]:
return
pipe = DummyPipeline(self.model)
self.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
@@ -210,19 +216,20 @@ class GroupOffloadTests(unittest.TestCase):
pipe.to(torch_device)
self.assertIn(f"The module '{self.model.__class__.__name__}' is group offloaded", cm.output[0])
def test_error_raised_if_streams_used_and_no_cuda_device(self):
original_is_available = torch.cuda.is_available
torch.cuda.is_available = lambda: False
def test_error_raised_if_streams_used_and_no_accelerator_device(self):
torch_accelerator_module = getattr(torch, torch_device, torch.cuda)
original_is_available = torch_accelerator_module.is_available
torch_accelerator_module.is_available = lambda: False
with self.assertRaises(ValueError):
self.model.enable_group_offload(
onload_device=torch.device("cuda"), offload_type="leaf_level", use_stream=True
onload_device=torch.device(torch_device), offload_type="leaf_level", use_stream=True
)
torch.cuda.is_available = original_is_available
torch_accelerator_module.is_available = original_is_available
def test_error_raised_if_supports_group_offloading_false(self):
self.model._supports_group_offloading = False
with self.assertRaisesRegex(ValueError, "does not support group offloading"):
self.model.enable_group_offload(onload_device=torch.device("cuda"))
self.model.enable_group_offload(onload_device=torch.device(torch_device))
def test_error_raised_if_model_offloading_applied_on_group_offloaded_module(self):
pipe = DummyPipeline(self.model)
@@ -249,7 +256,7 @@ class GroupOffloadTests(unittest.TestCase):
pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
def test_block_level_stream_with_invocation_order_different_from_initialization_order(self):
if torch.device(torch_device).type != "cuda":
if torch.device(torch_device).type not in ["cuda", "xpu"]:
return
model = DummyModelWithMultipleBlocks(
in_features=self.in_features,

View File

@@ -19,7 +19,7 @@ from diffusers import (
UNet2DConditionModel,
)
from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible, variant_compatible_siblings
from diffusers.utils.testing_utils import require_torch_gpu, torch_device
from diffusers.utils.testing_utils import require_torch_accelerator, torch_device
class IsSafetensorsCompatibleTests(unittest.TestCase):
@@ -850,9 +850,9 @@ class ProgressBarTests(unittest.TestCase):
self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")
@require_torch_gpu
@require_torch_accelerator
class PipelineDeviceAndDtypeStabilityTests(unittest.TestCase):
expected_pipe_device = torch.device("cuda:0")
expected_pipe_device = torch.device(f"{torch_device}:0")
expected_pipe_dtype = torch.float64
def get_dummy_components_image_generation(self):
@@ -921,8 +921,8 @@ class PipelineDeviceAndDtypeStabilityTests(unittest.TestCase):
pipe.to(device=torch_device, dtype=torch.float32)
pipe.unet.to(device="cpu")
pipe.vae.to(device="cuda")
pipe.text_encoder.to(device="cuda:0")
pipe.vae.to(device=torch_device)
pipe.text_encoder.to(device=f"{torch_device}:0")
pipe_device = pipe.device