From eef3d6595456e69a48989c2cc44c739792341e07 Mon Sep 17 00:00:00 2001
From: Yao Matrix
Date: Fri, 18 Apr 2025 07:27:41 +0800
Subject: [PATCH] enable 2 test cases on XPU (#11332)

* enable 2 test cases on XPU

Signed-off-by: YAO Matrix

* Apply style fixes

---------

Signed-off-by: YAO Matrix
Co-authored-by: github-actions[bot]
Co-authored-by: Dhruv Nair
---
 tests/quantization/bnb/test_mixed_int8.py |  4 +++-
 tests/quantization/utils.py               | 14 ++++++++++----
 2 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py
index 1049bfecba..a8aff679b5 100644
--- a/tests/quantization/bnb/test_mixed_int8.py
+++ b/tests/quantization/bnb/test_mixed_int8.py
@@ -523,13 +523,15 @@ class SlowBnb8bitTests(Base8bitTests):
             torch_dtype=torch.float16,
             device_map=torch_device,
         )
+        # CUDA device placement works.
+        device = torch_device if torch_device != "rocm" else "cuda"
 
         pipeline_8bit = DiffusionPipeline.from_pretrained(
             self.model_name,
             transformer=transformer_8bit,
             text_encoder_3=text_encoder_3_8bit,
             torch_dtype=torch.float16,
-        ).to("cuda")
+        ).to(device)
 
         # Check if inference works.
         _ = pipeline_8bit("table", max_sequence_length=20, num_inference_steps=2)
diff --git a/tests/quantization/utils.py b/tests/quantization/utils.py
index 04ebf9e159..d458a3e6d5 100644
--- a/tests/quantization/utils.py
+++ b/tests/quantization/utils.py
@@ -1,4 +1,10 @@
 from diffusers.utils import is_torch_available
+from diffusers.utils.testing_utils import (
+    backend_empty_cache,
+    backend_max_memory_allocated,
+    backend_reset_peak_memory_stats,
+    torch_device,
+)
 
 
 if is_torch_available():
@@ -30,9 +36,9 @@ if is_torch_available():
     @torch.no_grad()
     @torch.inference_mode()
     def get_memory_consumption_stat(model, inputs):
-        torch.cuda.reset_peak_memory_stats()
-        torch.cuda.empty_cache()
+        backend_reset_peak_memory_stats(torch_device)
+        backend_empty_cache(torch_device)
 
         model(**inputs)
-        max_memory_mem_allocated = torch.cuda.max_memory_allocated()
-        return max_memory_mem_allocated
+        max_mem_allocated = backend_max_memory_allocated(torch_device)
+        return max_mem_allocated
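
The second hunk works because the backend_* helpers in diffusers.utils.testing_utils dispatch on the active torch_device instead of assuming CUDA, so the same memory measurement runs on CUDA and XPU. The sketch below is illustrative only, not the actual diffusers implementation; it assumes a PyTorch build that exposes the torch.xpu memory APIs, and the function names are made up for this note.

import torch

def reset_peak_memory_stats_for(device: str) -> None:
    # Illustrative stand-in for a device-agnostic helper.
    # Normalize "cuda:0" -> "cuda", "xpu:1" -> "xpu".
    device_type = device.split(":")[0]
    if device_type == "cuda":
        torch.cuda.reset_peak_memory_stats()
    elif device_type == "xpu":
        # Assumes PyTorch with XPU support built in.
        torch.xpu.reset_peak_memory_stats()
    # CPU keeps no peak-memory counter, so there is nothing to reset.

def max_memory_allocated_for(device: str) -> int:
    # Illustrative stand-in: report peak device memory for the backend in use.
    device_type = device.split(":")[0]
    if device_type == "cuda":
        return torch.cuda.max_memory_allocated()
    if device_type == "xpu":
        return torch.xpu.max_memory_allocated()
    return 0  # no accelerator: no device memory was allocated

The ROCm branch in the first hunk maps torch_device == "rocm" to "cuda" because ROCm builds of PyTorch expose HIP devices under the "cuda" device type, so .to("cuda") is the correct placement call there.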