
enable lora cases on XPU (#11506)

* enable lora cases on XPU

Signed-off-by: Yao Matrix <matrix.yao@intel.com>

* remove hunyuanvideo xpu expectation

Signed-off-by: Yao Matrix <matrix.yao@intel.com>

---------

Signed-off-by: Yao Matrix <matrix.yao@intel.com>
This commit is contained in:
Yao Matrix
2025-05-06 17:29:50 +08:00
committed by GitHub
parent d7ffe60166
commit 8c661ea586
5 changed files with 44 additions and 34 deletions
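The change is the same pattern across all five files: hard-coded CUDA calls and decorators (torch.cuda.empty_cache(), .to("cuda"), @require_torch_gpu, @require_big_gpu_with_torch_cuda) are replaced by the device-agnostic helpers from diffusers.utils.testing_utils (backend_empty_cache(torch_device), .to(torch_device), @require_torch_accelerator, @require_big_accelerator), so the same LoRA tests run on CUDA and on Intel XPU. A minimal sketch of what such a cache-clearing dispatcher does (hypothetical simplification; the real helper lives in diffusers.utils.testing_utils and may handle more backends):

    import torch

    # Hypothetical sketch of a device-agnostic cache helper, mirroring what
    # backend_empty_cache(torch_device) does conceptually.
    def backend_empty_cache_sketch(device: str) -> None:
        if device == "cuda":
            torch.cuda.empty_cache()
        elif device == "xpu":
            # torch.xpu is available on PyTorch builds with XPU support
            torch.xpu.empty_cache()
        # nothing to free for "cpu"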


@@ -31,13 +31,14 @@ from diffusers import FlowMatchEulerDiscreteScheduler, FluxControlPipeline, Flux
from diffusers.utils import load_image, logging
from diffusers.utils.testing_utils import (
CaptureLogger,
+backend_empty_cache,
floats_tensor,
is_peft_available,
nightly,
numpy_cosine_similarity_distance,
-require_big_gpu_with_torch_cuda,
+require_big_accelerator,
require_peft_backend,
-require_torch_gpu,
+require_torch_accelerator,
slow,
torch_device,
)
@@ -809,10 +810,10 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@slow
@nightly
-@require_torch_gpu
+@require_torch_accelerator
@require_peft_backend
-@require_big_gpu_with_torch_cuda
-@pytest.mark.big_gpu_with_torch_cuda
+@require_big_accelerator
+@pytest.mark.big_accelerator
class FluxLoRAIntegrationTests(unittest.TestCase):
"""internal note: The integration slices were obtained on audace.
@@ -827,7 +828,7 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
super().setUp()
gc.collect()
-torch.cuda.empty_cache()
+backend_empty_cache(torch_device)
self.pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
@@ -836,13 +837,13 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
del self.pipeline
gc.collect()
-torch.cuda.empty_cache()
+backend_empty_cache(torch_device)
def test_flux_the_last_ben(self):
self.pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
-# Instead of calling `enable_model_cpu_offload()`, we do a cuda placement here because the CI
+# Instead of calling `enable_model_cpu_offload()`, we do an accelerator placement here because the CI
# run supports it. We have about 34GB RAM in the CI runner which kills the test when run with
# `enable_model_cpu_offload()`. We repeat this for the other tests, too.
self.pipeline = self.pipeline.to(torch_device)
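For context, the two placement strategies the comment above contrasts, sketched with the pipeline this test uses (torch_device stands for the accelerator string resolved by the test suite):

    import torch
    from diffusers import FluxPipeline

    torch_device = "cuda"  # or "xpu"; resolved automatically in the test suite

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    )

    # Avoided here: enable_model_cpu_offload() trades accelerator memory for
    # host RAM, which the ~34GB CI runner cannot spare.
    # pipe.enable_model_cpu_offload()

    # Used instead: keep the whole pipeline resident on the accelerator.
    pipe = pipe.to(torch_device)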
@@ -956,10 +957,10 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
@nightly
-@require_torch_gpu
+@require_torch_accelerator
@require_peft_backend
-@require_big_gpu_with_torch_cuda
-@pytest.mark.big_gpu_with_torch_cuda
+@require_big_accelerator
+@pytest.mark.big_accelerator
class FluxControlLoRAIntegrationTests(unittest.TestCase):
num_inference_steps = 10
seed = 0
@@ -969,17 +970,17 @@ class FluxControlLoRAIntegrationTests(unittest.TestCase):
super().setUp()
gc.collect()
-torch.cuda.empty_cache()
+backend_empty_cache(torch_device)
self.pipeline = FluxControlPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
).to(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
-torch.cuda.empty_cache()
+backend_empty_cache(torch_device)
@parameterized.expand(["black-forest-labs/FLUX.1-Canny-dev-lora", "black-forest-labs/FLUX.1-Depth-dev-lora"])
def test_lora(self, lora_ckpt_id):
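A decorator like @require_torch_accelerator can be thought of as a skipUnless wrapper keyed on the detected device; a hypothetical re-implementation for illustration only (the real decorator lives in diffusers.utils.testing_utils):

    import unittest

    import torch

    def _resolve_torch_device() -> str:
        # Simplified device probe; the test suite's torch_device covers more backends.
        if torch.cuda.is_available():
            return "cuda"
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            return "xpu"
        return "cpu"

    def require_torch_accelerator(test_case):
        # Skip the test unless some hardware accelerator is present.
        return unittest.skipUnless(
            _resolve_torch_device() != "cpu", "test requires a hardware accelerator"
        )(test_case)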


@@ -28,13 +28,16 @@ from diffusers import (
HunyuanVideoTransformer3DModel,
)
from diffusers.utils.testing_utils import (
+Expectations,
+backend_empty_cache,
floats_tensor,
nightly,
numpy_cosine_similarity_distance,
-require_big_gpu_with_torch_cuda,
+require_big_accelerator,
require_peft_backend,
-require_torch_gpu,
+require_torch_accelerator,
skip_mps,
torch_device,
)
@@ -192,10 +195,10 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@nightly
-@require_torch_gpu
+@require_torch_accelerator
@require_peft_backend
-@require_big_gpu_with_torch_cuda
-@pytest.mark.big_gpu_with_torch_cuda
+@require_big_accelerator
+@pytest.mark.big_accelerator
class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
"""internal note: The integration slices were obtained on DGX.
@@ -210,7 +213,7 @@ class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
super().setUp()
gc.collect()
-torch.cuda.empty_cache()
+backend_empty_cache(torch_device)
model_id = "hunyuanvideo-community/HunyuanVideo"
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
@@ -218,13 +221,13 @@ class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
)
self.pipeline = HunyuanVideoPipeline.from_pretrained(
model_id, transformer=transformer, torch_dtype=torch.float16
).to("cuda")
).to(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
-torch.cuda.empty_cache()
+backend_empty_cache(torch_device)
def test_original_format_cseti(self):
self.pipeline.load_lora_weights(
@@ -249,8 +252,13 @@ class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
out_slice = np.concatenate((out[:8], out[-8:]))
# fmt: off
-expected_slice = np.array([0.1013, 0.1924, 0.0078, 0.1021, 0.1929, 0.0078, 0.1023, 0.1919, 0.7402, 0.104, 0.4482, 0.7354, 0.0925, 0.4382, 0.7275, 0.0815])
+expected_slices = Expectations(
+    {
+        ("cuda", 7): np.array([0.1013, 0.1924, 0.0078, 0.1021, 0.1929, 0.0078, 0.1023, 0.1919, 0.7402, 0.104, 0.4482, 0.7354, 0.0925, 0.4382, 0.7275, 0.0815]),
+    }
+)
# fmt: on
+expected_slice = expected_slices.get_expectation()
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
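The Expectations container keys a baseline by a (device type, version) tuple (for the "cuda" entry above, the 7 plausibly denotes a device generation such as the compute-capability major), so an XPU slice can later be registered alongside the CUDA one without disturbing it, and get_expectation() returns the entry for the current backend. A hedged sketch of the lookup idea (simplified; the real class in diffusers.utils.testing_utils infers the device itself):

    import numpy as np

    # Hypothetical simplification: keys are (device_type, version) tuples.
    class ExpectationsSketch(dict):
        def get_expectation(self, device_type="cuda", version=7):
            if (device_type, version) in self:
                return self[(device_type, version)]
            # Fall back to any entry recorded for the same device type.
            for (dev, _), value in self.items():
                if dev == device_type:
                    return value
            raise KeyError(f"no expectation recorded for {device_type}")

    slices = ExpectationsSketch({("cuda", 7): np.array([0.1013, 0.1924])})
    expected = slices.get_expectation()  # -> the CUDA baseline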


@@ -93,12 +93,12 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
-torch.cuda.empty_cache()
+backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
-torch.cuda.empty_cache()
+backend_empty_cache(torch_device)
# Keeping this test here makes sense because it doesn't do any integration-style checks
# (value assertions on logits).


@@ -34,7 +34,7 @@ from diffusers.utils.testing_utils import (
is_flaky,
nightly,
numpy_cosine_similarity_distance,
-require_big_gpu_with_torch_cuda,
+require_big_accelerator,
require_peft_backend,
require_torch_accelerator,
torch_device,
@@ -138,8 +138,8 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@nightly
@require_torch_accelerator
@require_peft_backend
-@require_big_gpu_with_torch_cuda
-@pytest.mark.big_gpu_with_torch_cuda
+@require_big_accelerator
+@pytest.mark.big_accelerator
class SD3LoraIntegrationTests(unittest.TestCase):
pipeline_class = StableDiffusion3Img2ImgPipeline
repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"


@@ -37,12 +37,13 @@ from diffusers.utils import logging
from diffusers.utils.import_utils import is_accelerate_available
from diffusers.utils.testing_utils import (
CaptureLogger,
+backend_empty_cache,
is_flaky,
load_image,
nightly,
numpy_cosine_similarity_distance,
require_peft_backend,
-require_torch_gpu,
+require_torch_accelerator,
slow,
torch_device,
)
@@ -105,12 +106,12 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
-torch.cuda.empty_cache()
+backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
-torch.cuda.empty_cache()
+backend_empty_cache(torch_device)
@is_flaky
def test_multiple_wrong_adapter_name_raises_error(self):
@@ -119,18 +120,18 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
@slow
@nightly
-@require_torch_gpu
+@require_torch_accelerator
@require_peft_backend
class LoraSDXLIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
-torch.cuda.empty_cache()
+backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
-torch.cuda.empty_cache()
+backend_empty_cache(torch_device)
def test_sdxl_1_0_lora(self):
generator = torch.Generator("cpu").manual_seed(0)
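Note that the generator is seeded on the CPU rather than on the accelerator: a CPU-seeded generator produces bit-identical noise whether the pipeline then runs on CUDA or XPU, which is what keeps the expected slices comparable across backends. A quick illustration:

    import torch

    # A CPU-seeded generator yields the same draw regardless of where the
    # model itself runs, so reference outputs stay valid across backends.
    generator = torch.Generator("cpu").manual_seed(0)
    noise = torch.randn(4, generator=generator)  # identical on any machine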