mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
enable quantcompile test on xpu (#11988)
Signed-off-by: Yao, Matrix <matrix.yao@intel.com>
This commit is contained in:
@@ -18,10 +18,10 @@ import inspect
|
||||
import torch
|
||||
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.utils.testing_utils import backend_empty_cache, require_torch_gpu, slow, torch_device
|
||||
from diffusers.utils.testing_utils import backend_empty_cache, require_torch_accelerator, slow, torch_device
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_accelerator
|
||||
@slow
|
||||
class QuantCompileTests:
|
||||
@property
|
||||
@@ -51,7 +51,7 @@ class QuantCompileTests:
|
||||
return pipe
|
||||
|
||||
def _test_torch_compile(self, torch_dtype=torch.bfloat16):
|
||||
pipe = self._init_pipeline(self.quantization_config, torch_dtype).to("cuda")
|
||||
pipe = self._init_pipeline(self.quantization_config, torch_dtype).to(torch_device)
|
||||
# `fullgraph=True` ensures no graph breaks
|
||||
pipe.transformer.compile(fullgraph=True)
|
||||
|
||||
@@ -71,7 +71,7 @@ class QuantCompileTests:
|
||||
|
||||
pipe = self._init_pipeline(self.quantization_config, torch_dtype)
|
||||
group_offload_kwargs = {
|
||||
"onload_device": torch.device("cuda"),
|
||||
"onload_device": torch.device(torch_device),
|
||||
"offload_device": torch.device("cpu"),
|
||||
"offload_type": "leaf_level",
|
||||
"use_stream": use_stream,
|
||||
@@ -81,7 +81,7 @@ class QuantCompileTests:
|
||||
for name, component in pipe.components.items():
|
||||
if name != "transformer" and isinstance(component, torch.nn.Module):
|
||||
if torch.device(component.device).type == "cpu":
|
||||
component.to("cuda")
|
||||
component.to(torch_device)
|
||||
|
||||
# small resolutions to ensure speedy execution.
|
||||
pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
|
||||
|
||||
@@ -236,7 +236,7 @@ class TorchAoTest(unittest.TestCase):
|
||||
("uint7wo", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
|
||||
]
|
||||
|
||||
if TorchAoConfig._is_cuda_capability_atleast_8_9():
|
||||
if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9():
|
||||
QUANTIZATION_TYPES_TO_TEST.extend([
|
||||
("float8wo_e5m2", np.array([0.4590, 0.5273, 0.5547, 0.4219, 0.4375, 0.6406, 0.4316, 0.4512, 0.5625])),
|
||||
("float8wo_e4m3", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6406, 0.4316, 0.4531, 0.5625])),
|
||||
@@ -753,7 +753,7 @@ class SlowTorchAoTests(unittest.TestCase):
|
||||
("int8dq", np.array([0.0546, 0.0761, 0.1386, 0.0488, 0.0644, 0.1425, 0.0605, 0.0742, 0.1406, 0.0625, 0.0722, 0.1523, 0.0625, 0.0742, 0.1503, 0.0605, 0.3886, 0.7968, 0.5507, 0.4492, 0.7890, 0.5351, 0.4316, 0.8007, 0.5390, 0.4179, 0.8281, 0.5820, 0.4531, 0.7812, 0.5703, 0.4921])),
|
||||
]
|
||||
|
||||
if TorchAoConfig._is_cuda_capability_atleast_8_9():
|
||||
if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9():
|
||||
QUANTIZATION_TYPES_TO_TEST.extend([
|
||||
("float8wo_e4m3", np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])),
|
||||
("fp5_e3m1", np.array([0.0527, 0.0762, 0.1309, 0.0449, 0.0645, 0.1328, 0.0566, 0.0723, 0.125, 0.0566, 0.0703, 0.1328, 0.0566, 0.0742, 0.1348, 0.0566, 0.3633, 0.7617, 0.5273, 0.4277, 0.7891, 0.5469, 0.4375, 0.8008, 0.5586, 0.4336, 0.7383, 0.5156, 0.3906, 0.6992, 0.5156, 0.4375])),
|
||||
|
||||
Reference in New Issue
Block a user