From e8a3eec04fc107ed7294ccfe50f486223672b12f Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Wed, 18 Jun 2025 12:18:50 +0530
Subject: [PATCH] add compile + offload tests for GGUF.

---
Reviewer notes (free-form commentary placed after the "---" marker, outside
every hunk):

  What the patch does
    * Imports DiffusionPipeline, require_torch_version_greater and
      QuantCompileTests into tests/quantization/gguf/test_gguf.py.
    * Appends GGUFCompileTests, a QuantCompileTests subclass decorated with
      @require_torch_version_greater("2.7.1").
    * _init_pipeline() loads SD3Transformer2DModel from the single-file GGUF
      checkpoint in `gguf_ckpt`, passing the class-level quantization_config
      and torch_dtype (bfloat16), then hands that transformer to
      DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-3.5-medium").
      Its *args/**kwargs are accepted but not used.
    * test_torch_compile, test_torch_compile_with_cpu_offload and
      test_torch_compile_with_group_offload each forward
      self.quantization_config to the matching `_test_torch_compile*` helper
      reached through super().

  NOTE(review): DiffusionPipeline is inserted after StableDiffusion3Pipeline,
    and require_torch_version_greater before
    require_gguf_version_greater_or_equal, so neither import list stays in
    the alphabetical order of its visible neighbours. Confirm whether the
    repository's import sorter accepts this.
  NOTE(review): quantization_config is built in the class body, i.e. when the
    module is imported. Presumably GGUFQuantizationConfig can be constructed
    without the optional gguf package installed; verify.
  NOTE(review): the line structure of this patch was restored from a
    whitespace-collapsed copy. Leading indentation and the blank context line
    in the second hunk are reconstructed from Python syntax and the hunk
    line counts; check them against the target file before applying.

 tests/quantization/gguf/test_gguf.py | 31 +++++++++++++++++++++++++++-
 1 file changed, 30 insertions(+), 1 deletion(-)

diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py
index ae3900459d..c0e876c1a8 100644
--- a/tests/quantization/gguf/test_gguf.py
+++ b/tests/quantization/gguf/test_gguf.py
@@ -15,6 +15,7 @@ from diffusers import (
     HiDreamImageTransformer2DModel,
     SD3Transformer2DModel,
     StableDiffusion3Pipeline,
+    DiffusionPipeline,
 )
 from diffusers.utils import load_image
 from diffusers.utils.testing_utils import (
@@ -28,11 +29,12 @@ from diffusers.utils.testing_utils import (
     numpy_cosine_similarity_distance,
     require_accelerate,
     require_big_accelerator,
+    require_torch_version_greater,
     require_gguf_version_greater_or_equal,
     require_peft_backend,
     torch_device,
 )
-
+from ..test_torch_compile_utils import QuantCompileTests
 
 if is_gguf_available():
     from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter
@@ -577,3 +579,30 @@ class HiDreamGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
             ).to(torch_device, self.torch_dtype),
             "timesteps": torch.tensor([1]).to(torch_device, self.torch_dtype),
         }
+
+
+@require_torch_version_greater("2.7.1")
+class GGUFCompileTests(QuantCompileTests):
+    torch_dtype = torch.bfloat16
+    quantization_config = GGUFQuantizationConfig(compute_dtype=torch_dtype)
+    gguf_ckpt = "https://huggingface.co/city96/stable-diffusion-3.5-medium-gguf/blob/main/sd3.5_medium-Q3_K_M.gguf"
+
+    def _init_pipeline(self, *args, **kwargs):
+        transformer = SD3Transformer2DModel.from_single_file(
+            self.gguf_ckpt, quantization_config=self.quantization_config, torch_dtype=self.torch_dtype
+        )
+        pipe = DiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-3.5-medium",
+            transformer=transformer,
+            torch_dtype=self.torch_dtype
+        )
+        return pipe
+
+    def test_torch_compile(self):
+        super()._test_torch_compile(quantization_config=self.quantization_config)
+
+    def test_torch_compile_with_cpu_offload(self):
+        super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
+
+    def test_torch_compile_with_group_offload(self):
+        super()._test_torch_compile_with_group_offload(quantization_config=self.quantization_config)