
add compile + offload tests for GGUF.

Author: sayakpaul
Date:   2025-06-18 12:18:50 +05:30
Parent: 05e867784d
Commit: e8a3eec04f
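
Once the commit is applied, the new tests can be selected by class name. A minimal invocation, assuming the diff below belongs to tests/quantization/gguf/test_gguf.py (the file name is not shown on this page) and a torch version greater than 2.7.1, per the decorator:

    pytest tests/quantization/gguf/test_gguf.py -k GGUFCompileTests -v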


@@ -15,6 +15,7 @@ from diffusers import (
     HiDreamImageTransformer2DModel,
     SD3Transformer2DModel,
     StableDiffusion3Pipeline,
+    DiffusionPipeline,
 )
 from diffusers.utils import load_image
 from diffusers.utils.testing_utils import (
@@ -28,11 +29,12 @@ from diffusers.utils.testing_utils import (
     numpy_cosine_similarity_distance,
     require_accelerate,
     require_big_accelerator,
+    require_torch_version_greater,
     require_gguf_version_greater_or_equal,
     require_peft_backend,
     torch_device,
 )
+from ..test_torch_compile_utils import QuantCompileTests

 if is_gguf_available():
     from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter
@@ -577,3 +579,30 @@ class HiDreamGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
             ).to(torch_device, self.torch_dtype),
             "timesteps": torch.tensor([1]).to(torch_device, self.torch_dtype),
         }
+
+
+@require_torch_version_greater("2.7.1")
+class GGUFCompileTests(QuantCompileTests):
+    torch_dtype = torch.bfloat16
+    quantization_config = GGUFQuantizationConfig(compute_dtype=torch_dtype)
+    gguf_ckpt = "https://huggingface.co/city96/stable-diffusion-3.5-medium-gguf/blob/main/sd3.5_medium-Q3_K_M.gguf"
+
+    def _init_pipeline(self, *args, **kwargs):
+        transformer = SD3Transformer2DModel.from_single_file(
+            self.gguf_ckpt, quantization_config=self.quantization_config, torch_dtype=self.torch_dtype
+        )
+        pipe = DiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-3.5-medium",
+            transformer=transformer,
+            torch_dtype=self.torch_dtype,
+        )
+        return pipe
+
+    def test_torch_compile(self):
+        super()._test_torch_compile(quantization_config=self.quantization_config)
+
+    def test_torch_compile_with_cpu_offload(self):
+        super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
+
+    def test_torch_compile_with_group_offload(self):
+        super()._test_torch_compile_with_group_offload(quantization_config=self.quantization_config)
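
The inherited QuantCompileTests helpers are not shown in this diff; judging by their names, they run torch.compile on the quantized transformer alone and combined with CPU or group offloading. A rough sketch of the equivalent user-facing flow, assuming the public enable_model_cpu_offload and enable_group_offload APIs in diffusers:

    import torch

    from diffusers import DiffusionPipeline, GGUFQuantizationConfig, SD3Transformer2DModel

    ckpt = "https://huggingface.co/city96/stable-diffusion-3.5-medium-gguf/blob/main/sd3.5_medium-Q3_K_M.gguf"
    transformer = SD3Transformer2DModel.from_single_file(
        ckpt,
        quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
        torch_dtype=torch.bfloat16,
    )
    pipe = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-3.5-medium",
        transformer=transformer,
        torch_dtype=torch.bfloat16,
    )

    # Variant 1: accelerate-backed CPU offload of whole pipeline components.
    pipe.enable_model_cpu_offload()

    # Variant 2 (use instead of variant 1): group offloading on the transformer
    # only; assumes the enable_group_offload model API.
    # pipe.transformer.enable_group_offload(
    #     onload_device=torch.device("cuda"),
    #     offload_device=torch.device("cpu"),
    #     offload_type="leaf_level",
    # )

    # Compile the GGUF-quantized transformer; GGUF weights are dequantized on
    # the fly at compute time, so compilation has to cope with that extra step.
    pipe.transformer.compile()

    image = pipe("a photo of a cat", num_inference_steps=4).images[0]

The two offload modes are exclusive in this sketch, which mirrors the tests: each of the three methods exercises compile with exactly one offloading strategy (or none).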