mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
add compile + offload tests for GGUF.
This commit is contained in:
@@ -15,6 +15,7 @@ from diffusers import (
|
||||
HiDreamImageTransformer2DModel,
|
||||
SD3Transformer2DModel,
|
||||
StableDiffusion3Pipeline,
|
||||
DiffusionPipeline,
|
||||
)
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.testing_utils import (
|
||||
@@ -28,11 +29,12 @@ from diffusers.utils.testing_utils import (
|
||||
numpy_cosine_similarity_distance,
|
||||
require_accelerate,
|
||||
require_big_accelerator,
|
||||
require_torch_version_greater,
|
||||
require_gguf_version_greater_or_equal,
|
||||
require_peft_backend,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_torch_compile_utils import QuantCompileTests
|
||||
|
||||
if is_gguf_available():
|
||||
from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter
|
||||
@@ -577,3 +579,30 @@ class HiDreamGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
|
||||
).to(torch_device, self.torch_dtype),
|
||||
"timesteps": torch.tensor([1]).to(torch_device, self.torch_dtype),
|
||||
}
|
||||
|
||||
|
||||
@require_torch_version_greater("2.7.1")
class GGUFCompileTests(QuantCompileTests):
    """Compile and offload tests for a GGUF-quantized SD3.5-medium transformer.

    Every test method forwards the class-level ``quantization_config`` to the
    matching ``_test_*`` helper inherited from ``QuantCompileTests``.
    ``_init_pipeline`` is overridden here so the pipeline wraps a transformer
    loaded from a single-file GGUF checkpoint.
    NOTE(review): presumably the inherited helpers call ``_init_pipeline`` --
    confirm against ``QuantCompileTests``.
    """

    # Used as the GGUF compute dtype and as the load dtype for both the
    # transformer and the pipeline. Must be defined before
    # ``quantization_config`` below, which reads it at class-creation time.
    torch_dtype = torch.bfloat16

    # Single-file GGUF checkpoint that the transformer is loaded from.
    gguf_ckpt = "https://huggingface.co/city96/stable-diffusion-3.5-medium-gguf/blob/main/sd3.5_medium-Q3_K_M.gguf"

    # Built once when the class body executes and shared by all tests.
    quantization_config = GGUFQuantizationConfig(compute_dtype=torch_dtype)

    def _init_pipeline(self, *args, **kwargs):
        """Return a pipeline built around the GGUF-quantized transformer.

        Positional and keyword arguments are accepted but not used; the
        checkpoint, quantization config and dtype all come from class
        attributes.
        """
        dtype = self.torch_dtype
        quantized_transformer = SD3Transformer2DModel.from_single_file(
            self.gguf_ckpt,
            quantization_config=self.quantization_config,
            torch_dtype=dtype,
        )
        return DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-3.5-medium",
            transformer=quantized_transformer,
            torch_dtype=dtype,
        )

    # The tests below are thin entry points: each one hands the shared
    # quantization config to the corresponding helper on the parent class.

    def test_torch_compile(self):
        config = self.quantization_config
        super()._test_torch_compile(quantization_config=config)

    def test_torch_compile_with_cpu_offload(self):
        config = self.quantization_config
        super()._test_torch_compile_with_cpu_offload(quantization_config=config)

    def test_torch_compile_with_group_offload(self):
        config = self.quantization_config
        super()._test_torch_compile_with_group_offload(quantization_config=config)
|
||||
|
||||
Reference in New Issue
Block a user