diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index 609c9ad15a..871faf076e 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -493,7 +493,7 @@ class TorchAoConfig(QuantizationConfigMixin): TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method() if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys(): is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp") - if is_floating_quant_type and not self._is_cuda_capability_atleast_8_9(): + if is_floating_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9(): raise ValueError( f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability <= 8.9. You " f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`." @@ -645,7 +645,7 @@ class TorchAoConfig(QuantizationConfigMixin): QUANTIZATION_TYPES.update(INT8_QUANTIZATION_TYPES) QUANTIZATION_TYPES.update(UINTX_QUANTIZATION_DTYPES) - if cls._is_cuda_capability_atleast_8_9(): + if cls._is_xpu_or_cuda_capability_atleast_8_9(): QUANTIZATION_TYPES.update(FLOATX_QUANTIZATION_TYPES) return QUANTIZATION_TYPES @@ -655,14 +655,16 @@ class TorchAoConfig(QuantizationConfigMixin): ) @staticmethod - def _is_cuda_capability_atleast_8_9() -> bool: - if not torch.cuda.is_available(): - raise RuntimeError("TorchAO requires a CUDA compatible GPU and installation of PyTorch.") - - major, minor = torch.cuda.get_device_capability() - if major == 8: - return minor >= 9 - return major >= 9 + def _is_xpu_or_cuda_capability_atleast_8_9() -> bool: + if torch.cuda.is_available(): + major, minor = torch.cuda.get_device_capability() + if major == 8: + return minor >= 9 + return major >= 9 + elif torch.xpu.is_available(): + return True + else: + raise RuntimeError("TorchAO requires a CUDA compatible GPU or Intel XPU and installation of PyTorch.") def get_apply_tensor_subclass(self): TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method() diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index e19a9f83fd..ae5a6e6e91 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -300,9 +300,7 @@ def require_torch_gpu(test_case): def require_torch_cuda_compatibility(expected_compute_capability): def decorator(test_case): - if not torch.cuda.is_available(): - return unittest.skip(test_case) - else: + if torch.cuda.is_available(): current_compute_capability = get_torch_cuda_device_capability() return unittest.skipUnless( float(current_compute_capability) == float(expected_compute_capability), diff --git a/tests/models/autoencoders/test_models_consistency_decoder_vae.py b/tests/models/autoencoders/test_models_consistency_decoder_vae.py index 77977a78d8..db87004fcb 100644 --- a/tests/models/autoencoders/test_models_consistency_decoder_vae.py +++ b/tests/models/autoencoders/test_models_consistency_decoder_vae.py @@ -21,6 +21,7 @@ import torch from diffusers import ConsistencyDecoderVAE, StableDiffusionPipeline from diffusers.utils.testing_utils import ( + backend_empty_cache, enable_full_determinism, load_image, slow, @@ -162,13 +163,13 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase): # clean up the VRAM before each test super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): # clean up the VRAM after each test 
super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) @torch.no_grad() def test_encode_decode(self): diff --git a/tests/models/unets/test_models_unet_2d.py b/tests/models/unets/test_models_unet_2d.py index 0e5fdc4bba..1a7959a877 100644 --- a/tests/models/unets/test_models_unet_2d.py +++ b/tests/models/unets/test_models_unet_2d.py @@ -22,6 +22,7 @@ import torch from diffusers import UNet2DModel from diffusers.utils import logging from diffusers.utils.testing_utils import ( + backend_empty_cache, enable_full_determinism, floats_tensor, require_torch_accelerator, @@ -229,7 +230,7 @@ class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): # two models don't need to stay in the device at the same time del model_accelerate - torch.cuda.empty_cache() + backend_empty_cache(torch_device) gc.collect() model_normal_load, _ = UNet2DModel.from_pretrained( diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py index ab0dcbc1de..c8ed68c65b 100644 --- a/tests/models/unets/test_models_unet_2d_condition.py +++ b/tests/models/unets/test_models_unet_2d_condition.py @@ -46,7 +46,6 @@ from diffusers.utils.testing_utils import ( require_peft_backend, require_torch_accelerator, require_torch_accelerator_with_fp16, - require_torch_gpu, skip_mps, slow, torch_all_close, @@ -978,13 +977,13 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test assert sample2.allclose(sample5, atol=1e-4, rtol=1e-4) assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4) - @require_torch_gpu @parameterized.expand( [ ("hf-internal-testing/unet2d-sharded-dummy", None), ("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"), ] ) + @require_torch_accelerator def test_load_sharded_checkpoint_from_hub(self, repo_id, variant): _, inputs_dict = self.prepare_init_args_and_inputs_for_common() loaded_model = self.model_class.from_pretrained(repo_id, variant=variant) @@ -994,13 +993,13 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test assert loaded_model assert new_output.sample.shape == (4, 4, 16, 16) - @require_torch_gpu @parameterized.expand( [ ("hf-internal-testing/unet2d-sharded-dummy-subfolder", None), ("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"), ] ) + @require_torch_accelerator def test_load_sharded_checkpoint_from_hub_subfolder(self, repo_id, variant): _, inputs_dict = self.prepare_init_args_and_inputs_for_common() loaded_model = self.model_class.from_pretrained(repo_id, subfolder="unet", variant=variant) diff --git a/tests/pipelines/allegro/test_allegro.py b/tests/pipelines/allegro/test_allegro.py index 30fdd68cfd..30a14ef7f5 100644 --- a/tests/pipelines/allegro/test_allegro.py +++ b/tests/pipelines/allegro/test_allegro.py @@ -24,6 +24,7 @@ from transformers import AutoTokenizer, T5Config, T5EncoderModel from diffusers import AllegroPipeline, AllegroTransformer3DModel, AutoencoderKLAllegro, DDIMScheduler from diffusers.utils.testing_utils import ( + backend_empty_cache, enable_full_determinism, numpy_cosine_similarity_distance, require_hf_hub_version_greater, @@ -341,12 +342,12 @@ class AllegroPipelineIntegrationTests(unittest.TestCase): def setUp(self): super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def test_allegro(self): generator = 
torch.Generator("cpu").manual_seed(0) diff --git a/tests/pipelines/audioldm/test_audioldm.py b/tests/pipelines/audioldm/test_audioldm.py index aaf44985aa..340bc24adc 100644 --- a/tests/pipelines/audioldm/test_audioldm.py +++ b/tests/pipelines/audioldm/test_audioldm.py @@ -37,7 +37,7 @@ from diffusers import ( UNet2DConditionModel, ) from diffusers.utils import is_xformers_available -from diffusers.utils.testing_utils import enable_full_determinism, nightly, torch_device +from diffusers.utils.testing_utils import backend_empty_cache, enable_full_determinism, nightly, torch_device from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS from ..test_pipelines_common import PipelineTesterMixin @@ -378,12 +378,12 @@ class AudioLDMPipelineSlowTests(unittest.TestCase): def setUp(self): super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0): generator = torch.Generator(device=generator_device).manual_seed(seed) @@ -423,12 +423,12 @@ class AudioLDMPipelineNightlyTests(unittest.TestCase): def setUp(self): super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0): generator = torch.Generator(device=generator_device).manual_seed(seed) diff --git a/tests/pipelines/audioldm2/test_audioldm2.py b/tests/pipelines/audioldm2/test_audioldm2.py index a8f60fb6dc..14b5510fca 100644 --- a/tests/pipelines/audioldm2/test_audioldm2.py +++ b/tests/pipelines/audioldm2/test_audioldm2.py @@ -45,7 +45,13 @@ from diffusers import ( LMSDiscreteScheduler, PNDMScheduler, ) -from diffusers.utils.testing_utils import enable_full_determinism, is_torch_version, nightly, torch_device +from diffusers.utils.testing_utils import ( + backend_empty_cache, + enable_full_determinism, + is_torch_version, + nightly, + torch_device, +) from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS from ..test_pipelines_common import PipelineTesterMixin @@ -540,12 +546,12 @@ class AudioLDM2PipelineSlowTests(unittest.TestCase): def setUp(self): super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0): generator = torch.Generator(device=generator_device).manual_seed(seed) diff --git a/tests/pipelines/cogvideo/test_cogvideox.py b/tests/pipelines/cogvideo/test_cogvideox.py index a9de0ff05f..a6349c99c5 100644 --- a/tests/pipelines/cogvideo/test_cogvideox.py +++ b/tests/pipelines/cogvideo/test_cogvideox.py @@ -22,6 +22,7 @@ from transformers import AutoTokenizer, T5EncoderModel from diffusers import AutoencoderKLCogVideoX, CogVideoXPipeline, CogVideoXTransformer3DModel, DDIMScheduler from diffusers.utils.testing_utils import ( + backend_empty_cache, enable_full_determinism, numpy_cosine_similarity_distance, require_torch_accelerator, @@ -334,12 +335,12 @@ class CogVideoXPipelineIntegrationTests(unittest.TestCase): def setUp(self): super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def 
tearDown(self): super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def test_cogvideox(self): generator = torch.Generator("cpu").manual_seed(0) diff --git a/tests/pipelines/cogview3/test_cogview3plus.py b/tests/pipelines/cogview3/test_cogview3plus.py index 79dffd230a..4eca68dd7b 100644 --- a/tests/pipelines/cogview3/test_cogview3plus.py +++ b/tests/pipelines/cogview3/test_cogview3plus.py @@ -22,6 +22,7 @@ from transformers import AutoTokenizer, T5EncoderModel from diffusers import AutoencoderKL, CogVideoXDDIMScheduler, CogView3PlusPipeline, CogView3PlusTransformer2DModel from diffusers.utils.testing_utils import ( + backend_empty_cache, enable_full_determinism, numpy_cosine_similarity_distance, require_torch_accelerator, @@ -244,12 +245,12 @@ class CogView3PlusPipelineIntegrationTests(unittest.TestCase): def setUp(self): super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def test_cogview3plus(self): generator = torch.Generator("cpu").manual_seed(0) diff --git a/tests/pipelines/controlnet/test_controlnet_img2img.py b/tests/pipelines/controlnet/test_controlnet_img2img.py index 100765ee34..0147d4a651 100644 --- a/tests/pipelines/controlnet/test_controlnet_img2img.py +++ b/tests/pipelines/controlnet/test_controlnet_img2img.py @@ -36,6 +36,7 @@ from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetMo from diffusers.utils import load_image from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import ( + backend_empty_cache, enable_full_determinism, floats_tensor, load_numpy, @@ -412,12 +413,12 @@ class ControlNetImg2ImgPipelineSlowTests(unittest.TestCase): def setUp(self): super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def test_canny(self): controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny") diff --git a/tests/pipelines/controlnet/test_controlnet_inpaint.py b/tests/pipelines/controlnet/test_controlnet_inpaint.py index b06590e13c..63d5fd4660 100644 --- a/tests/pipelines/controlnet/test_controlnet_inpaint.py +++ b/tests/pipelines/controlnet/test_controlnet_inpaint.py @@ -36,6 +36,7 @@ from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetMo from diffusers.utils import load_image from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import ( + backend_empty_cache, enable_full_determinism, floats_tensor, load_numpy, @@ -464,12 +465,12 @@ class ControlNetInpaintPipelineSlowTests(unittest.TestCase): def setUp(self): super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def test_canny(self): controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny") diff --git a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py index 1be15645ef..7880f744b9 100644 --- a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py +++ b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py @@ -221,7 +221,7 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, 
PipelineTes @slow @require_big_accelerator -@pytest.mark.big_gpu_with_torch_cuda +@pytest.mark.big_accelerator class StableDiffusion3ControlNetPipelineSlowTests(unittest.TestCase): pipeline_class = StableDiffusion3ControlNetPipeline diff --git a/tests/pipelines/deepfloyd_if/test_if.py b/tests/pipelines/deepfloyd_if/test_if.py index 295b29f12e..8445229079 100644 --- a/tests/pipelines/deepfloyd_if/test_if.py +++ b/tests/pipelines/deepfloyd_if/test_if.py @@ -25,6 +25,7 @@ from diffusers.models.attention_processor import AttnAddedKVProcessor from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import ( backend_empty_cache, + backend_max_memory_allocated, backend_reset_max_memory_allocated, backend_reset_peak_memory_stats, load_numpy, @@ -135,7 +136,7 @@ class IFPipelineSlowTests(unittest.TestCase): image = output.images[0] - mem_bytes = torch.cuda.max_memory_allocated() + mem_bytes = backend_max_memory_allocated(torch_device) assert mem_bytes < 12 * 10**9 expected_image = load_numpy( diff --git a/tests/pipelines/deepfloyd_if/test_if_img2img.py b/tests/pipelines/deepfloyd_if/test_if_img2img.py index da06dc3558..14271a9862 100644 --- a/tests/pipelines/deepfloyd_if/test_if_img2img.py +++ b/tests/pipelines/deepfloyd_if/test_if_img2img.py @@ -24,6 +24,7 @@ from diffusers.models.attention_processor import AttnAddedKVProcessor from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import ( backend_empty_cache, + backend_max_memory_allocated, backend_reset_max_memory_allocated, backend_reset_peak_memory_stats, floats_tensor, @@ -151,7 +152,7 @@ class IFImg2ImgPipelineSlowTests(unittest.TestCase): ) image = output.images[0] - mem_bytes = torch.cuda.max_memory_allocated() + mem_bytes = backend_max_memory_allocated(torch_device) assert mem_bytes < 12 * 10**9 expected_image = load_numpy( diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py index 646ad928ec..cbdf617d71 100644 --- a/tests/pipelines/flux/test_pipeline_flux.py +++ b/tests/pipelines/flux/test_pipeline_flux.py @@ -224,7 +224,7 @@ class FluxPipelineFastTests( @nightly @require_big_accelerator -@pytest.mark.big_gpu_with_torch_cuda +@pytest.mark.big_accelerator class FluxPipelineSlowTests(unittest.TestCase): pipeline_class = FluxPipeline repo_id = "black-forest-labs/FLUX.1-schnell" @@ -312,7 +312,7 @@ class FluxPipelineSlowTests(unittest.TestCase): @slow @require_big_accelerator -@pytest.mark.big_gpu_with_torch_cuda +@pytest.mark.big_accelerator class FluxIPAdapterPipelineSlowTests(unittest.TestCase): pipeline_class = FluxPipeline repo_id = "black-forest-labs/FLUX.1-dev" diff --git a/tests/pipelines/flux/test_pipeline_flux_redux.py b/tests/pipelines/flux/test_pipeline_flux_redux.py index 1f204add1c..b8f36dfd3c 100644 --- a/tests/pipelines/flux/test_pipeline_flux_redux.py +++ b/tests/pipelines/flux/test_pipeline_flux_redux.py @@ -19,7 +19,7 @@ from diffusers.utils.testing_utils import ( @slow @require_big_accelerator -@pytest.mark.big_gpu_with_torch_cuda +@pytest.mark.big_accelerator class FluxReduxSlowTests(unittest.TestCase): pipeline_class = FluxPriorReduxPipeline repo_id = "black-forest-labs/FLUX.1-Redux-dev" diff --git a/tests/pipelines/hunyuandit/test_hunyuan_dit.py b/tests/pipelines/hunyuandit/test_hunyuan_dit.py index 66453b73b0..05c94262ab 100644 --- a/tests/pipelines/hunyuandit/test_hunyuan_dit.py +++ b/tests/pipelines/hunyuandit/test_hunyuan_dit.py @@ -23,6 +23,7 @@ from transformers import 
AutoTokenizer, BertModel, T5EncoderModel from diffusers import AutoencoderKL, DDPMScheduler, HunyuanDiT2DModel, HunyuanDiTPipeline from diffusers.utils.testing_utils import ( + backend_empty_cache, enable_full_determinism, numpy_cosine_similarity_distance, require_torch_accelerator, @@ -310,12 +311,12 @@ class HunyuanDiTPipelineIntegrationTests(unittest.TestCase): def setUp(self): super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def test_hunyuan_dit_1024(self): generator = torch.Generator("cpu").manual_seed(0) diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py index c66491b15c..8399e57bfb 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py @@ -27,6 +27,7 @@ from diffusers import ( UNet2DConditionModel, ) from diffusers.utils.testing_utils import ( + backend_empty_cache, load_numpy, nightly, numpy_cosine_similarity_distance, @@ -231,12 +232,12 @@ class StableDiffusionAttendAndExcitePipelineIntegrationTests(unittest.TestCase): def setUp(self): super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def test_attend_and_excite_fp16(self): generator = torch.manual_seed(51) diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py index 2feeaaf11c..ff4a33abf8 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py @@ -25,6 +25,7 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer from diffusers import AutoencoderKL, PNDMScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel from diffusers.utils.testing_utils import ( backend_empty_cache, + backend_max_memory_allocated, backend_reset_max_memory_allocated, backend_reset_peak_memory_stats, enable_full_determinism, @@ -287,6 +288,6 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase): output_type="np", ) - mem_bytes = torch.cuda.max_memory_allocated() + mem_bytes = backend_max_memory_allocated(torch_device) # make sure that less than 2.65 GB is allocated assert mem_bytes < 2.65 * 10**9 diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py index 8e2fa77fc0..577ac4ebdd 100644 --- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py +++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py @@ -233,7 +233,7 @@ class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin): @slow @require_big_accelerator -@pytest.mark.big_gpu_with_torch_cuda +@pytest.mark.big_accelerator class StableDiffusion3PipelineSlowTests(unittest.TestCase): pipeline_class = StableDiffusion3Pipeline repo_id = "stabilityai/stable-diffusion-3-medium-diffusers" diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py index 
80bb35a08e..f5b5e63a81 100644 --- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py +++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py @@ -168,7 +168,7 @@ class StableDiffusion3Img2ImgPipelineFastTests(PipelineLatentTesterMixin, unitte @slow @require_big_accelerator -@pytest.mark.big_gpu_with_torch_cuda +@pytest.mark.big_accelerator class StableDiffusion3Img2ImgPipelineSlowTests(unittest.TestCase): pipeline_class = StableDiffusion3Img2ImgPipeline repo_id = "stabilityai/stable-diffusion-3-medium-diffusers" diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py index a41e7dc7f3..11f08c8820 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py @@ -35,6 +35,7 @@ from diffusers import ( UniPCMultistepScheduler, ) from diffusers.utils.testing_utils import ( + backend_empty_cache, enable_full_determinism, load_image, numpy_cosine_similarity_distance, @@ -940,12 +941,12 @@ class StableDiffusionXLPipelineIntegrationTests(unittest.TestCase): def setUp(self): super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def test_stable_diffusion_lcm(self): torch.manual_seed(0) diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py index 9a141634a3..7d19d745a2 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py @@ -39,6 +39,7 @@ from diffusers import ( UNet2DConditionModel, ) from diffusers.utils.testing_utils import ( + backend_empty_cache, enable_full_determinism, floats_tensor, load_image, @@ -670,12 +671,12 @@ class StableDiffusionXLImg2ImgPipelineIntegrationTests(unittest.TestCase): def setUp(self): super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def test_stable_diffusion_xl_img2img_playground(self): torch.manual_seed(0) diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 65718a2545..c4db662784 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -1218,13 +1218,13 @@ class PipelineFastTests(unittest.TestCase): # clean up the VRAM before each test super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): # clean up the VRAM after each test super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def dummy_image(self): batch_size = 1 diff --git a/tests/pipelines/wan/test_wan.py b/tests/pipelines/wan/test_wan.py index a162e6841d..e3a153dd19 100644 --- a/tests/pipelines/wan/test_wan.py +++ b/tests/pipelines/wan/test_wan.py @@ -21,9 +21,11 @@ from transformers import AutoTokenizer, T5EncoderModel from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanPipeline, WanTransformer3DModel from diffusers.utils.testing_utils import ( + backend_empty_cache, enable_full_determinism, require_torch_accelerator, slow, + torch_device, ) from ..pipeline_params import 
TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS @@ -144,12 +146,12 @@ class WanPipelineIntegrationTests(unittest.TestCase): def setUp(self): super().setUp() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def tearDown(self): super().tearDown() gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) @unittest.skip("TODO: test needs to be implemented") def test_Wanx(self): diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py index 0e671307dd..743da17356 100644 --- a/tests/quantization/torchao/test_torchao.py +++ b/tests/quantization/torchao/test_torchao.py @@ -30,13 +30,15 @@ from diffusers import ( ) from diffusers.models.attention_processor import Attention from diffusers.utils.testing_utils import ( + backend_empty_cache, + backend_synchronize, enable_full_determinism, is_torch_available, is_torchao_available, nightly, numpy_cosine_similarity_distance, require_torch, - require_torch_gpu, + require_torch_accelerator, require_torchao_version_greater_or_equal, slow, torch_device, @@ -61,7 +63,7 @@ if is_torchao_available(): @require_torch -@require_torch_gpu +@require_torch_accelerator @require_torchao_version_greater_or_equal("0.7.0") class TorchAoConfigTest(unittest.TestCase): def test_to_dict(self): @@ -79,7 +81,7 @@ class TorchAoConfigTest(unittest.TestCase): Test kwargs validations in TorchAoConfig """ _ = TorchAoConfig("int4_weight_only") - with self.assertRaisesRegex(ValueError, "is not supported yet"): + with self.assertRaisesRegex(ValueError, "is not supported"): _ = TorchAoConfig("uint8") with self.assertRaisesRegex(ValueError, "does not support the following keyword arguments"): @@ -119,12 +121,12 @@ class TorchAoConfigTest(unittest.TestCase): # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners @require_torch -@require_torch_gpu +@require_torch_accelerator @require_torchao_version_greater_or_equal("0.7.0") class TorchAoTest(unittest.TestCase): def tearDown(self): gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def get_dummy_components( self, quantization_config: TorchAoConfig, model_id: str = "hf-internal-testing/tiny-flux-pipe" @@ -269,6 +271,7 @@ class TorchAoTest(unittest.TestCase): subfolder="transformer", quantization_config=quantization_config, torch_dtype=torch.bfloat16, + device_map=f"{torch_device}:0", ) weight = quantized_model.transformer_blocks[0].ff.net[2].weight @@ -338,7 +341,7 @@ class TorchAoTest(unittest.TestCase): output = quantized_model(**inputs)[0] output_slice = output.flatten()[-9:].detach().float().cpu().numpy() - self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3) + self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 2e-3) with tempfile.TemporaryDirectory() as offload_folder: quantization_config = TorchAoConfig("int4_weight_only", group_size=64) @@ -359,7 +362,7 @@ class TorchAoTest(unittest.TestCase): output = quantized_model(**inputs)[0] output_slice = output.flatten()[-9:].detach().float().cpu().numpy() - self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3) + self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 2e-3) def test_modules_to_not_convert(self): quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"]) @@ -518,14 +521,14 @@ class TorchAoTest(unittest.TestCase): # Slices for these tests have been 
obtained on our aws-g6e-xlarge-plus runners @require_torch -@require_torch_gpu +@require_torch_accelerator @require_torchao_version_greater_or_equal("0.7.0") class TorchAoSerializationTest(unittest.TestCase): model_name = "hf-internal-testing/tiny-flux-pipe" def tearDown(self): gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def get_dummy_model(self, quant_method, quant_method_kwargs, device=None): quantization_config = TorchAoConfig(quant_method, **quant_method_kwargs) @@ -593,17 +596,17 @@ class TorchAoSerializationTest(unittest.TestCase): ) self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3) - def test_int_a8w8_cuda(self): + def test_int_a8w8_accelerator(self): quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {} expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551]) - device = "cuda" + device = torch_device self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice) self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device) - def test_int_a16w8_cuda(self): + def test_int_a16w8_accelerator(self): quant_method, quant_method_kwargs = "int8_weight_only", {} expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551]) - device = "cuda" + device = torch_device self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice) self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device) @@ -624,14 +627,14 @@ class TorchAoSerializationTest(unittest.TestCase): # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners @require_torch -@require_torch_gpu +@require_torch_accelerator @require_torchao_version_greater_or_equal("0.7.0") @slow @nightly class SlowTorchAoTests(unittest.TestCase): def tearDown(self): gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def get_dummy_components(self, quantization_config: TorchAoConfig): # This is just for convenience, so that we can modify it at one place for custom environments and locally testing @@ -713,8 +716,8 @@ class SlowTorchAoTests(unittest.TestCase): quantization_config = TorchAoConfig(quant_type=quantization_name, modules_to_not_convert=["x_embedder"]) self._test_quant_type(quantization_config, expected_slice) gc.collect() - torch.cuda.empty_cache() - torch.cuda.synchronize() + backend_empty_cache(torch_device) + backend_synchronize(torch_device) def test_serialization_int8wo(self): quantization_config = TorchAoConfig("int8wo") @@ -733,8 +736,8 @@ class SlowTorchAoTests(unittest.TestCase): pipe.remove_all_hooks() del pipe.transformer gc.collect() - torch.cuda.empty_cache() - torch.cuda.synchronize() + backend_empty_cache(torch_device) + backend_synchronize(torch_device) transformer = FluxTransformer2DModel.from_pretrained( tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False ) @@ -783,14 +786,14 @@ class SlowTorchAoTests(unittest.TestCase): @require_torch -@require_torch_gpu +@require_torch_accelerator @require_torchao_version_greater_or_equal("0.7.0") @slow @nightly class SlowTorchAoPreserializedModelTests(unittest.TestCase): def tearDown(self): gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) def get_dummy_inputs(self, device: torch.device, seed: int = 0): if str(device).startswith("mps"): diff --git a/tests/single_file/test_lumina2_transformer.py 
b/tests/single_file/test_lumina2_transformer.py index d3ffd4fc3a..2ac681897d 100644 --- a/tests/single_file/test_lumina2_transformer.py +++ b/tests/single_file/test_lumina2_transformer.py @@ -16,8 +16,6 @@ import gc import unittest -import torch - from diffusers import ( Lumina2Transformer2DModel, ) @@ -66,9 +64,9 @@ class Lumina2Transformer2DModelSingleFileTests(unittest.TestCase): def test_checkpoint_loading(self): for ckpt_path in self.alternate_keys_ckpt_paths: - torch.cuda.empty_cache() + backend_empty_cache(torch_device) model = self.model_class.from_single_file(ckpt_path) del model gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) diff --git a/tests/single_file/test_model_flux_transformer_single_file.py b/tests/single_file/test_model_flux_transformer_single_file.py index bf11faaa9c..81779cf8fa 100644 --- a/tests/single_file/test_model_flux_transformer_single_file.py +++ b/tests/single_file/test_model_flux_transformer_single_file.py @@ -16,8 +16,6 @@ import gc import unittest -import torch - from diffusers import ( FluxTransformer2DModel, ) @@ -64,9 +62,9 @@ class FluxTransformer2DModelSingleFileTests(unittest.TestCase): def test_checkpoint_loading(self): for ckpt_path in self.alternate_keys_ckpt_paths: - torch.cuda.empty_cache() + backend_empty_cache(torch_device) model = self.model_class.from_single_file(ckpt_path) del model gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) diff --git a/tests/single_file/test_sana_transformer.py b/tests/single_file/test_sana_transformer.py index 802ca37abf..e74c5be6ff 100644 --- a/tests/single_file/test_sana_transformer.py +++ b/tests/single_file/test_sana_transformer.py @@ -1,8 +1,6 @@ import gc import unittest -import torch - from diffusers import ( SanaTransformer2DModel, ) @@ -53,9 +51,9 @@ class SanaTransformer2DModelSingleFileTests(unittest.TestCase): def test_checkpoint_loading(self): for ckpt_path in self.alternate_keys_ckpt_paths: - torch.cuda.empty_cache() + backend_empty_cache(torch_device) model = self.model_class.from_single_file(ckpt_path) del model gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device)
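
For context, the recurring change above swaps CUDA-specific calls (torch.cuda.empty_cache(), torch.cuda.max_memory_allocated(), torch.cuda.synchronize()) for the device-agnostic helpers from diffusers.utils.testing_utils (backend_empty_cache, backend_max_memory_allocated, backend_synchronize), so the same tests can run on CUDA GPUs and Intel XPU. Below is a minimal sketch of the dispatch pattern such a helper relies on; the name backend_empty_cache_sketch is illustrative only and this is not the actual diffusers implementation. The other backend_* helpers follow the same shape with the corresponding per-backend torch calls.

import torch

def backend_empty_cache_sketch(device: str) -> None:
    # Free cached allocator memory on whichever accelerator `device` names
    # ("cuda", "cuda:0", "xpu", "mps", "cpu", ...).
    device_type = torch.device(device).type
    if device_type == "cuda":
        torch.cuda.empty_cache()
    elif device_type == "xpu":
        torch.xpu.empty_cache()
    elif device_type == "mps":
        torch.mps.empty_cache()
    # CPU (and unknown backends) need no cache clearing, so fall through silently.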