enable group_offload cases and quanto cases on XPU (#11405)
* enable group_offload cases and quanto cases on XPU
  Signed-off-by: YAO Matrix <matrix.yao@intel.com>
* use backend APIs
  Signed-off-by: Yao Matrix <matrix.yao@intel.com>
* fix style
  Signed-off-by: Yao Matrix <matrix.yao@intel.com>
---------
Signed-off-by: YAO Matrix <matrix.yao@intel.com>
Signed-off-by: Yao Matrix <matrix.yao@intel.com>
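The substance of the change: CUDA-only calls in the tests (torch.cuda.empty_cache(), torch.cuda.reset_peak_memory_stats(), hard-coded "cuda" device strings, @require_torch_gpu / @require_big_gpu_with_torch_cuda) are replaced with device-agnostic helpers from diffusers.utils.testing_utils (backend_empty_cache, backend_reset_peak_memory_stats, torch_device, @require_torch_accelerator, @require_big_accelerator), so the same tests can also run on Intel XPU. As a rough sketch of the idea only, not the actual diffusers implementation, such a backend helper could dispatch on the device type:

import torch

# Illustrative sketch; the real backend_empty_cache lives in
# diffusers.utils.testing_utils and may be implemented differently.
def backend_empty_cache(device: str) -> None:
    device_type = torch.device(device).type
    if device_type == "cuda":
        torch.cuda.empty_cache()
    elif device_type == "xpu":
        torch.xpu.empty_cache()
    # CPU and other backends: nothing to clear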
@@ -53,7 +53,7 @@ from diffusers.utils.testing_utils import (
     require_accelerator,
     require_hf_hub_version_greater,
     require_torch,
-    require_torch_gpu,
+    require_torch_accelerator,
     require_transformers_version_greater,
     skip_mps,
     torch_device,
@@ -2212,7 +2212,7 @@ class PipelineTesterMixin:
         inputs = self.get_dummy_inputs(torch_device)
         _ = pipe(**inputs)[0]

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_group_offloading_inference(self):
         if not self.test_group_offloading:
             return
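The swapped-in @require_torch_accelerator decorator gates the test on any supported accelerator rather than CUDA specifically. A minimal sketch of such a decorator, assuming the torch_device string used throughout the diffusers test suite (the real helper in diffusers.utils.testing_utils may check more than this):

import unittest

from diffusers.utils.testing_utils import torch_device  # e.g. "cuda", "xpu", "mps", or "cpu"

# Illustrative sketch only, not the actual diffusers decorator.
def require_torch_accelerator(test_case):
    return unittest.skipUnless(
        torch_device != "cpu", "test requires a hardware accelerator"
    )(test_case)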
@@ -6,10 +6,13 @@ from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig
 from diffusers.models.attention_processor import Attention
 from diffusers.utils import is_optimum_quanto_available, is_torch_available
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
+    backend_reset_peak_memory_stats,
     enable_full_determinism,
     nightly,
     numpy_cosine_similarity_distance,
     require_accelerate,
-    require_big_gpu_with_torch_cuda,
+    require_big_accelerator,
+    require_torch_cuda_compatibility,
     torch_device,
 )
@@ -23,9 +26,11 @@ if is_torch_available():

     from ..utils import LoRALayer, get_memory_consumption_stat

 enable_full_determinism()


 @nightly
-@require_big_gpu_with_torch_cuda
+@require_big_accelerator
 @require_accelerate
 class QuantoBaseTesterMixin:
     model_id = None
@@ -39,13 +44,13 @@ class QuantoBaseTesterMixin:
     _test_torch_compile = False

     def setUp(self):
-        torch.cuda.reset_peak_memory_stats()
-        torch.cuda.empty_cache()
+        backend_reset_peak_memory_stats(torch_device)
+        backend_empty_cache(torch_device)
         gc.collect()

     def tearDown(self):
-        torch.cuda.reset_peak_memory_stats()
-        torch.cuda.empty_cache()
+        backend_reset_peak_memory_stats(torch_device)
+        backend_empty_cache(torch_device)
         gc.collect()

     def get_dummy_init_kwargs(self):
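The same fixture pattern generalizes to any test class. Assuming only the helpers imported in the hunk above, a standalone (hypothetical) test case could reset accelerator state between tests like this:

import gc
import unittest

from diffusers.utils.testing_utils import (
    backend_empty_cache,
    backend_reset_peak_memory_stats,
    torch_device,
)


# Hypothetical test class illustrating the setUp/tearDown pattern from the diff.
class ExampleQuantoTest(unittest.TestCase):
    def setUp(self):
        backend_reset_peak_memory_stats(torch_device)
        backend_empty_cache(torch_device)
        gc.collect()

    def tearDown(self):
        backend_reset_peak_memory_stats(torch_device)
        backend_empty_cache(torch_device)
        gc.collect()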
@@ -89,7 +94,7 @@ class QuantoBaseTesterMixin:
         self.model_cls._keep_in_fp32_modules = self.keep_in_fp32_module

         model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
-        model.to("cuda")
+        model.to(torch_device)

         for name, module in model.named_modules():
             if isinstance(module, torch.nn.Linear):
@@ -107,7 +112,7 @@ class QuantoBaseTesterMixin:
         init_kwargs.update({"quantization_config": quantization_config})

         model = self.model_cls.from_pretrained(**init_kwargs)
-        model.to("cuda")
+        model.to(torch_device)

         for name, module in model.named_modules():
             if name in self.modules_to_not_convert:
@@ -122,7 +127,8 @@ class QuantoBaseTesterMixin:

         with self.assertRaises(ValueError):
             # Tries with a `device` and `dtype`
-            model.to(device="cuda:0", dtype=torch.float16)
+            device_0 = f"{torch_device}:0"
+            model.to(device=device_0, dtype=torch.float16)

         with self.assertRaises(ValueError):
             # Tries with a cast
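Because torch_device can be "cuda" or "xpu", the indexed device string is now built with an f-string instead of hard-coding "cuda:0". A tiny illustration, assuming torch_device from diffusers.utils.testing_utils:

from diffusers.utils.testing_utils import torch_device

# On an Intel GPU machine torch_device is "xpu"; on NVIDIA it is "cuda",
# so the indexed string becomes "xpu:0" or "cuda:0" without hard-coding either.
device_0 = f"{torch_device}:0"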
@@ -133,7 +139,7 @@ class QuantoBaseTesterMixin:
             model.half()

         # This should work
-        model.to("cuda")
+        model.to(torch_device)

     def test_serialization(self):
         model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())