
enable torchao test cases on XPU and switch to device agnostic APIs for test cases (#11654)

* enable torchao cases on XPU

Signed-off-by: Matrix YAO <matrix.yao@intel.com>

* device agnostic APIs

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* more

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* fix style

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* enable test_torch_compile_recompilation_and_graph_break on XPU

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* resolve comments

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

---------

Signed-off-by: Matrix YAO <matrix.yao@intel.com>
Signed-off-by: YAO Matrix <matrix.yao@intel.com>
Author: Yao Matrix
Date: 2025-06-11 17:47:06 +08:00
Committed by: GitHub
Parent: e27142ac64
Commit: 33e636cea5
30 changed files with 109 additions and 91 deletions
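
Note on the pattern: the recurring change in this commit swaps direct `torch.cuda.*` calls for `backend_*` helpers from `diffusers.utils.testing_utils`, which dispatch on `torch_device`. A minimal sketch of the dispatch idea (an illustration only; the real helpers are table-driven and cover more backends and operations):

import torch

def backend_empty_cache(device: str) -> None:
    # Route the call to whichever accelerator backend the test session uses;
    # on CPU there is no device cache to clear, so fall through to a no-op.
    if device == "cuda":
        torch.cuda.empty_cache()
    elif device == "xpu":
        torch.xpu.empty_cache()
    elif device == "mps":
        torch.mps.empty_cache()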

View File

@@ -493,7 +493,7 @@ class TorchAoConfig(QuantizationConfigMixin):
         TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
         if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
             is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp")
-            if is_floating_quant_type and not self._is_cuda_capability_atleast_8_9():
+            if is_floating_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
                 raise ValueError(
                     f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability <= 8.9. You "
                     f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`."
@@ -645,7 +645,7 @@ class TorchAoConfig(QuantizationConfigMixin):
         QUANTIZATION_TYPES.update(INT8_QUANTIZATION_TYPES)
         QUANTIZATION_TYPES.update(UINTX_QUANTIZATION_DTYPES)
-        if cls._is_cuda_capability_atleast_8_9():
+        if cls._is_xpu_or_cuda_capability_atleast_8_9():
             QUANTIZATION_TYPES.update(FLOATX_QUANTIZATION_TYPES)
         return QUANTIZATION_TYPES
@@ -655,14 +655,16 @@
         )

     @staticmethod
-    def _is_cuda_capability_atleast_8_9() -> bool:
-        if not torch.cuda.is_available():
-            raise RuntimeError("TorchAO requires a CUDA compatible GPU and installation of PyTorch.")
-        major, minor = torch.cuda.get_device_capability()
-        if major == 8:
-            return minor >= 9
-        return major >= 9
+    def _is_xpu_or_cuda_capability_atleast_8_9() -> bool:
+        if torch.cuda.is_available():
+            major, minor = torch.cuda.get_device_capability()
+            if major == 8:
+                return minor >= 9
+            return major >= 9
+        elif torch.xpu.is_available():
+            return True
+        else:
+            raise RuntimeError("TorchAO requires a CUDA compatible GPU or Intel XPU and installation of PyTorch.")

     def get_apply_tensor_subclass(self):
         TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
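
With this gate, float8-style quant types stay restricted to CUDA devices with compute capability 8.9 or newer, while any available XPU is treated as supported. A usage sketch (the `float8wo_e4m3` shorthand is assumed from the torchao type mapping):

import torch
from diffusers import TorchAoConfig

if torch.cuda.is_available():
    print(torch.cuda.get_device_capability())  # e.g. (8, 9) on Ada
# Constructs on XPU or on a capable CUDA GPU; on older CUDA GPUs the
# config raises the ValueError shown above.
config = TorchAoConfig("float8wo_e4m3")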

View File

@@ -300,9 +300,7 @@ def require_torch_gpu(test_case):
 def require_torch_cuda_compatibility(expected_compute_capability):
     def decorator(test_case):
-        if not torch.cuda.is_available():
-            return unittest.skip(test_case)
-        else:
+        if torch.cuda.is_available():
             current_compute_capability = get_torch_cuda_device_capability()
             return unittest.skipUnless(
                 float(current_compute_capability) == float(expected_compute_capability),
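
After this change the decorator only applies the capability comparison when CUDA is present, rather than unconditionally skipping on machines without CUDA; the remainder of the function is outside this hunk and presumably returns the test unchanged on other accelerators. A usage sketch with a hypothetical test:

import unittest
from diffusers.utils.testing_utils import require_torch_cuda_compatibility

class ExampleTests(unittest.TestCase):
    # Runs only when the CUDA GPU reports compute capability 8.0.
    @require_torch_cuda_compatibility(8.0)
    def test_sm80_specific_behavior(self):
        self.assertTrue(True)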

View File

@@ -21,6 +21,7 @@ import torch
 from diffusers import ConsistencyDecoderVAE, StableDiffusionPipeline
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     load_image,
     slow,
@@ -162,13 +163,13 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
         # clean up the VRAM before each test
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         # clean up the VRAM after each test
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     @torch.no_grad()
     def test_encode_decode(self):

View File

@@ -22,6 +22,7 @@ import torch
 from diffusers import UNet2DModel
 from diffusers.utils import logging
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     floats_tensor,
     require_torch_accelerator,
@@ -229,7 +230,7 @@ class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         # two models don't need to stay in the device at the same time
         del model_accelerate
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
         gc.collect()

         model_normal_load, _ = UNet2DModel.from_pretrained(

View File

@@ -46,7 +46,6 @@ from diffusers.utils.testing_utils import (
     require_peft_backend,
     require_torch_accelerator,
     require_torch_accelerator_with_fp16,
-    require_torch_gpu,
     skip_mps,
     slow,
     torch_all_close,
@@ -978,13 +977,13 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
         assert sample2.allclose(sample5, atol=1e-4, rtol=1e-4)
         assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4)

-    @require_torch_gpu
     @parameterized.expand(
         [
             ("hf-internal-testing/unet2d-sharded-dummy", None),
             ("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
         ]
     )
+    @require_torch_accelerator
     def test_load_sharded_checkpoint_from_hub(self, repo_id, variant):
         _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         loaded_model = self.model_class.from_pretrained(repo_id, variant=variant)
@@ -994,13 +993,13 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)

-    @require_torch_gpu
     @parameterized.expand(
         [
             ("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
             ("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
         ]
     )
+    @require_torch_accelerator
     def test_load_sharded_checkpoint_from_hub_subfolder(self, repo_id, variant):
         _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         loaded_model = self.model_class.from_pretrained(repo_id, subfolder="unet", variant=variant)

View File

@@ -24,6 +24,7 @@ from transformers import AutoTokenizer, T5Config, T5EncoderModel
 from diffusers import AllegroPipeline, AllegroTransformer3DModel, AutoencoderKLAllegro, DDIMScheduler
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     numpy_cosine_similarity_distance,
     require_hf_hub_version_greater,
@@ -341,12 +342,12 @@ class AllegroPipelineIntegrationTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_allegro(self):
         generator = torch.Generator("cpu").manual_seed(0)

View File

@@ -37,7 +37,7 @@ from diffusers import (
     UNet2DConditionModel,
 )
 from diffusers.utils import is_xformers_available
-from diffusers.utils.testing_utils import enable_full_determinism, nightly, torch_device
+from diffusers.utils.testing_utils import backend_empty_cache, enable_full_determinism, nightly, torch_device

 from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS
 from ..test_pipelines_common import PipelineTesterMixin
@@ -378,12 +378,12 @@ class AudioLDMPipelineSlowTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
         generator = torch.Generator(device=generator_device).manual_seed(seed)
@@ -423,12 +423,12 @@ class AudioLDMPipelineNightlyTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
         generator = torch.Generator(device=generator_device).manual_seed(seed)

View File

@@ -45,7 +45,13 @@ from diffusers import (
     LMSDiscreteScheduler,
     PNDMScheduler,
 )
-from diffusers.utils.testing_utils import enable_full_determinism, is_torch_version, nightly, torch_device
+from diffusers.utils.testing_utils import (
+    backend_empty_cache,
+    enable_full_determinism,
+    is_torch_version,
+    nightly,
+    torch_device,
+)

 from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS
 from ..test_pipelines_common import PipelineTesterMixin
@@ -540,12 +546,12 @@ class AudioLDM2PipelineSlowTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
         generator = torch.Generator(device=generator_device).manual_seed(seed)

View File

@@ -22,6 +22,7 @@ from transformers import AutoTokenizer, T5EncoderModel
 from diffusers import AutoencoderKLCogVideoX, CogVideoXPipeline, CogVideoXTransformer3DModel, DDIMScheduler
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     numpy_cosine_similarity_distance,
     require_torch_accelerator,
@@ -334,12 +335,12 @@ class CogVideoXPipelineIntegrationTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_cogvideox(self):
         generator = torch.Generator("cpu").manual_seed(0)

View File

@@ -22,6 +22,7 @@ from transformers import AutoTokenizer, T5EncoderModel
 from diffusers import AutoencoderKL, CogVideoXDDIMScheduler, CogView3PlusPipeline, CogView3PlusTransformer2DModel
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     numpy_cosine_similarity_distance,
     require_torch_accelerator,
@@ -244,12 +245,12 @@ class CogView3PlusPipelineIntegrationTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_cogview3plus(self):
         generator = torch.Generator("cpu").manual_seed(0)

View File

@@ -36,6 +36,7 @@ from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetMo
 from diffusers.utils import load_image
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     floats_tensor,
     load_numpy,
@@ -412,12 +413,12 @@ class ControlNetImg2ImgPipelineSlowTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_canny(self):
         controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")

View File

@@ -36,6 +36,7 @@ from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetMo
 from diffusers.utils import load_image
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     floats_tensor,
     load_numpy,
@@ -464,12 +465,12 @@ class ControlNetInpaintPipelineSlowTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_canny(self):
         controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")

View File

@@ -221,7 +221,7 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes
 @slow
 @require_big_accelerator
-@pytest.mark.big_gpu_with_torch_cuda
+@pytest.mark.big_accelerator
 class StableDiffusion3ControlNetPipelineSlowTests(unittest.TestCase):
     pipeline_class = StableDiffusion3ControlNetPipeline

View File

@@ -25,6 +25,7 @@ from diffusers.models.attention_processor import AttnAddedKVProcessor
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (
     backend_empty_cache,
+    backend_max_memory_allocated,
     backend_reset_max_memory_allocated,
     backend_reset_peak_memory_stats,
     load_numpy,
@@ -135,7 +136,7 @@ class IFPipelineSlowTests(unittest.TestCase):
         image = output.images[0]

-        mem_bytes = torch.cuda.max_memory_allocated()
+        mem_bytes = backend_max_memory_allocated(torch_device)
         assert mem_bytes < 12 * 10**9

         expected_image = load_numpy(
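
`backend_max_memory_allocated` mirrors the cache helper: it reads the peak allocation from whichever backend `torch_device` names. A minimal sketch of the idea (an assumption, not the actual diffusers implementation):

import torch

def backend_max_memory_allocated(device: str) -> int:
    # Peak bytes allocated on the active accelerator since the last stats reset.
    if device == "cuda":
        return torch.cuda.max_memory_allocated()
    elif device == "xpu":
        return torch.xpu.max_memory_allocated()
    return 0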

View File

@@ -24,6 +24,7 @@ from diffusers.models.attention_processor import AttnAddedKVProcessor
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (
     backend_empty_cache,
+    backend_max_memory_allocated,
     backend_reset_max_memory_allocated,
     backend_reset_peak_memory_stats,
     floats_tensor,
@@ -151,7 +152,7 @@ class IFImg2ImgPipelineSlowTests(unittest.TestCase):
         )
         image = output.images[0]

-        mem_bytes = torch.cuda.max_memory_allocated()
+        mem_bytes = backend_max_memory_allocated(torch_device)
         assert mem_bytes < 12 * 10**9

         expected_image = load_numpy(

View File

@@ -224,7 +224,7 @@ class FluxPipelineFastTests(
 @nightly
 @require_big_accelerator
-@pytest.mark.big_gpu_with_torch_cuda
+@pytest.mark.big_accelerator
 class FluxPipelineSlowTests(unittest.TestCase):
     pipeline_class = FluxPipeline
     repo_id = "black-forest-labs/FLUX.1-schnell"
@@ -312,7 +312,7 @@ class FluxPipelineSlowTests(unittest.TestCase):
 @slow
 @require_big_accelerator
-@pytest.mark.big_gpu_with_torch_cuda
+@pytest.mark.big_accelerator
 class FluxIPAdapterPipelineSlowTests(unittest.TestCase):
     pipeline_class = FluxPipeline
     repo_id = "black-forest-labs/FLUX.1-dev"

View File

@@ -19,7 +19,7 @@ from diffusers.utils.testing_utils import (
 @slow
 @require_big_accelerator
-@pytest.mark.big_gpu_with_torch_cuda
+@pytest.mark.big_accelerator
 class FluxReduxSlowTests(unittest.TestCase):
     pipeline_class = FluxPriorReduxPipeline
     repo_id = "black-forest-labs/FLUX.1-Redux-dev"

View File

@@ -23,6 +23,7 @@ from transformers import AutoTokenizer, BertModel, T5EncoderModel
 from diffusers import AutoencoderKL, DDPMScheduler, HunyuanDiT2DModel, HunyuanDiTPipeline
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     numpy_cosine_similarity_distance,
     require_torch_accelerator,
@@ -310,12 +311,12 @@ class HunyuanDiTPipelineIntegrationTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_hunyuan_dit_1024(self):
         generator = torch.Generator("cpu").manual_seed(0)

View File

@@ -27,6 +27,7 @@ from diffusers import (
     UNet2DConditionModel,
 )
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     load_numpy,
     nightly,
     numpy_cosine_similarity_distance,
@@ -231,12 +232,12 @@ class StableDiffusionAttendAndExcitePipelineIntegrationTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_attend_and_excite_fp16(self):
         generator = torch.manual_seed(51)

View File

@@ -25,6 +25,7 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 from diffusers import AutoencoderKL, PNDMScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel
 from diffusers.utils.testing_utils import (
     backend_empty_cache,
+    backend_max_memory_allocated,
     backend_reset_max_memory_allocated,
     backend_reset_peak_memory_stats,
     enable_full_determinism,
@@ -287,6 +288,6 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
             output_type="np",
         )

-        mem_bytes = torch.cuda.max_memory_allocated()
+        mem_bytes = backend_max_memory_allocated(torch_device)
         # make sure that less than 2.65 GB is allocated
         assert mem_bytes < 2.65 * 10**9

View File

@@ -233,7 +233,7 @@ class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin):
 @slow
 @require_big_accelerator
-@pytest.mark.big_gpu_with_torch_cuda
+@pytest.mark.big_accelerator
 class StableDiffusion3PipelineSlowTests(unittest.TestCase):
     pipeline_class = StableDiffusion3Pipeline
     repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"

View File

@@ -168,7 +168,7 @@ class StableDiffusion3Img2ImgPipelineFastTests(PipelineLatentTesterMixin, unitte
 @slow
 @require_big_accelerator
-@pytest.mark.big_gpu_with_torch_cuda
+@pytest.mark.big_accelerator
 class StableDiffusion3Img2ImgPipelineSlowTests(unittest.TestCase):
     pipeline_class = StableDiffusion3Img2ImgPipeline
     repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"

View File

@@ -35,6 +35,7 @@ from diffusers import (
     UniPCMultistepScheduler,
 )
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     load_image,
     numpy_cosine_similarity_distance,
@@ -940,12 +941,12 @@ class StableDiffusionXLPipelineIntegrationTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_stable_diffusion_lcm(self):
         torch.manual_seed(0)

View File

@@ -39,6 +39,7 @@ from diffusers import (
     UNet2DConditionModel,
 )
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     floats_tensor,
     load_image,
@@ -670,12 +671,12 @@ class StableDiffusionXLImg2ImgPipelineIntegrationTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_stable_diffusion_xl_img2img_playground(self):
         torch.manual_seed(0)

View File

@@ -1218,13 +1218,13 @@ class PipelineFastTests(unittest.TestCase):
         # clean up the VRAM before each test
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         # clean up the VRAM after each test
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def dummy_image(self):
         batch_size = 1

View File

@@ -21,9 +21,11 @@ from transformers import AutoTokenizer, T5EncoderModel
 from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanPipeline, WanTransformer3DModel
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     require_torch_accelerator,
     slow,
+    torch_device,
 )

 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
@@ -144,12 +146,12 @@ class WanPipelineIntegrationTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     @unittest.skip("TODO: test needs to be implemented")
     def test_Wanx(self):

View File

@@ -30,13 +30,15 @@ from diffusers import (
 )
 from diffusers.models.attention_processor import Attention
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
+    backend_synchronize,
     enable_full_determinism,
     is_torch_available,
     is_torchao_available,
     nightly,
     numpy_cosine_similarity_distance,
     require_torch,
-    require_torch_gpu,
+    require_torch_accelerator,
     require_torchao_version_greater_or_equal,
     slow,
     torch_device,
@@ -61,7 +63,7 @@ if is_torchao_available():
 @require_torch
-@require_torch_gpu
+@require_torch_accelerator
 @require_torchao_version_greater_or_equal("0.7.0")
 class TorchAoConfigTest(unittest.TestCase):
     def test_to_dict(self):
@@ -79,7 +81,7 @@ class TorchAoConfigTest(unittest.TestCase):
         Test kwargs validations in TorchAoConfig
         """
         _ = TorchAoConfig("int4_weight_only")
-        with self.assertRaisesRegex(ValueError, "is not supported yet"):
+        with self.assertRaisesRegex(ValueError, "is not supported"):
             _ = TorchAoConfig("uint8")

         with self.assertRaisesRegex(ValueError, "does not support the following keyword arguments"):
@@ -119,12 +121,12 @@
 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch
-@require_torch_gpu
+@require_torch_accelerator
 @require_torchao_version_greater_or_equal("0.7.0")
 class TorchAoTest(unittest.TestCase):
     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def get_dummy_components(
         self, quantization_config: TorchAoConfig, model_id: str = "hf-internal-testing/tiny-flux-pipe"
@@ -269,6 +271,7 @@
             subfolder="transformer",
             quantization_config=quantization_config,
             torch_dtype=torch.bfloat16,
+            device_map=f"{torch_device}:0",
         )
         weight = quantized_model.transformer_blocks[0].ff.net[2].weight
@@ -338,7 +341,7 @@
             output = quantized_model(**inputs)[0]
             output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
-            self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
+            self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 2e-3)

         with tempfile.TemporaryDirectory() as offload_folder:
             quantization_config = TorchAoConfig("int4_weight_only", group_size=64)
@@ -359,7 +362,7 @@
             output = quantized_model(**inputs)[0]
             output_slice = output.flatten()[-9:].detach().float().cpu().numpy()
-            self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)
+            self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 2e-3)

     def test_modules_to_not_convert(self):
         quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"])
@@ -518,14 +521,14 @@
 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch
-@require_torch_gpu
+@require_torch_accelerator
 @require_torchao_version_greater_or_equal("0.7.0")
 class TorchAoSerializationTest(unittest.TestCase):
     model_name = "hf-internal-testing/tiny-flux-pipe"

     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def get_dummy_model(self, quant_method, quant_method_kwargs, device=None):
         quantization_config = TorchAoConfig(quant_method, **quant_method_kwargs)
@@ -593,17 +596,17 @@
         )
         self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3)

-    def test_int_a8w8_cuda(self):
+    def test_int_a8w8_accelerator(self):
         quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {}
         expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551])
-        device = "cuda"
+        device = torch_device
         self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
         self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)

-    def test_int_a16w8_cuda(self):
+    def test_int_a16w8_accelerator(self):
         quant_method, quant_method_kwargs = "int8_weight_only", {}
         expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
-        device = "cuda"
+        device = torch_device
         self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
         self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
@@ -624,14 +627,14 @@
 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch
-@require_torch_gpu
+@require_torch_accelerator
 @require_torchao_version_greater_or_equal("0.7.0")
 @slow
 @nightly
 class SlowTorchAoTests(unittest.TestCase):
     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def get_dummy_components(self, quantization_config: TorchAoConfig):
         # This is just for convenience, so that we can modify it at one place for custom environments and locally testing
@@ -713,8 +716,8 @@
             quantization_config = TorchAoConfig(quant_type=quantization_name, modules_to_not_convert=["x_embedder"])
             self._test_quant_type(quantization_config, expected_slice)
             gc.collect()
-            torch.cuda.empty_cache()
-            torch.cuda.synchronize()
+            backend_empty_cache(torch_device)
+            backend_synchronize(torch_device)

     def test_serialization_int8wo(self):
         quantization_config = TorchAoConfig("int8wo")
@@ -733,8 +736,8 @@
         pipe.remove_all_hooks()
         del pipe.transformer
         gc.collect()
-        torch.cuda.empty_cache()
-        torch.cuda.synchronize()
+        backend_empty_cache(torch_device)
+        backend_synchronize(torch_device)
         transformer = FluxTransformer2DModel.from_pretrained(
             tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False
         )
@@ -783,14 +786,14 @@
 @require_torch
-@require_torch_gpu
+@require_torch_accelerator
 @require_torchao_version_greater_or_equal("0.7.0")
 @slow
 @nightly
 class SlowTorchAoPreserializedModelTests(unittest.TestCase):
     def tearDown(self):
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def get_dummy_inputs(self, device: torch.device, seed: int = 0):
         if str(device).startswith("mps"):
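
The `device_map=f"{torch_device}:0"` addition above pins the quantized model to the first accelerator device, whether that is `cuda:0` or `xpu:0`. A sketch of the pattern the test exercises (model id taken from the test itself):

import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig
from diffusers.utils.testing_utils import torch_device

quantized_model = FluxTransformer2DModel.from_pretrained(
    "hf-internal-testing/tiny-flux-pipe",
    subfolder="transformer",
    quantization_config=TorchAoConfig("int8_weight_only"),
    torch_dtype=torch.bfloat16,
    device_map=f"{torch_device}:0",  # resolves to "cuda:0" or "xpu:0"
)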

View File

@@ -16,8 +16,6 @@
 import gc
 import unittest

-import torch
-
 from diffusers import (
     Lumina2Transformer2DModel,
 )
@@ -66,9 +64,9 @@ class Lumina2Transformer2DModelSingleFileTests(unittest.TestCase):
     def test_checkpoint_loading(self):
         for ckpt_path in self.alternate_keys_ckpt_paths:
-            torch.cuda.empty_cache()
+            backend_empty_cache(torch_device)
             model = self.model_class.from_single_file(ckpt_path)

             del model
             gc.collect()
-            torch.cuda.empty_cache()
+            backend_empty_cache(torch_device)

View File

@@ -16,8 +16,6 @@
 import gc
 import unittest

-import torch
-
 from diffusers import (
     FluxTransformer2DModel,
 )
@@ -64,9 +62,9 @@ class FluxTransformer2DModelSingleFileTests(unittest.TestCase):
     def test_checkpoint_loading(self):
         for ckpt_path in self.alternate_keys_ckpt_paths:
-            torch.cuda.empty_cache()
+            backend_empty_cache(torch_device)
             model = self.model_class.from_single_file(ckpt_path)

             del model
             gc.collect()
-            torch.cuda.empty_cache()
+            backend_empty_cache(torch_device)

View File

@@ -1,8 +1,6 @@
 import gc
 import unittest

-import torch
-
 from diffusers import (
     SanaTransformer2DModel,
 )
@@ -53,9 +51,9 @@ class SanaTransformer2DModelSingleFileTests(unittest.TestCase):
     def test_checkpoint_loading(self):
         for ckpt_path in self.alternate_keys_ckpt_paths:
-            torch.cuda.empty_cache()
+            backend_empty_cache(torch_device)
             model = self.model_class.from_single_file(ckpt_path)

             del model
             gc.collect()
-            torch.cuda.empty_cache()
+            backend_empty_cache(torch_device)