1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

enable dit integration cases on xpu (#11523)

* enable dit integration test on XPU

Signed-off-by: Yao Matrix <matrix.yao@intel.com>

* fix style

Signed-off-by: Yao Matrix <matrix.yao@intel.com>

---------

Signed-off-by: Yao Matrix <matrix.yao@intel.com>
This commit is contained in:
Yao Matrix
2025-05-09 18:36:50 +08:00
committed by GitHub
parent 3c0a0129fe
commit d6bf268a4a

View File

@@ -21,7 +21,15 @@ import torch
from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, DiTTransformer2DModel, DPMSolverMultistepScheduler
from diffusers.utils import is_xformers_available
from diffusers.utils.testing_utils import enable_full_determinism, load_numpy, nightly, require_torch_gpu, torch_device
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
load_numpy,
nightly,
numpy_cosine_similarity_distance,
require_torch_accelerator,
torch_device,
)
from ..pipeline_params import (
CLASS_CONDITIONED_IMAGE_GENERATION_BATCH_PARAMS,
@@ -107,23 +115,23 @@ class DiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
@nightly
@require_torch_gpu
@require_torch_accelerator
class DiTPipelineIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def test_dit_256(self):
generator = torch.manual_seed(0)
pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256")
pipe.to("cuda")
pipe.to(torch_device)
words = ["vase", "umbrella", "white shark", "white wolf"]
ids = pipe.get_label_ids(words)
@@ -139,7 +147,7 @@ class DiTPipelineIntegrationTests(unittest.TestCase):
def test_dit_512(self):
pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-512")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")
pipe.to(torch_device)
words = ["vase", "umbrella"]
ids = pipe.get_label_ids(words)
@@ -152,4 +160,7 @@ class DiTPipelineIntegrationTests(unittest.TestCase):
f"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/dit/{word}_512.npy"
)
assert np.abs((expected_image - image).max()) < 1e-1
expected_slice = expected_image.flatten()
output_slice = image.flatten()
assert numpy_cosine_similarity_distance(expected_slice, output_slice) < 1e-2