From ba2ba9019f76fd96c532240ed07d3f98343e4041 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 6 Aug 2025 00:06:48 +0800 Subject: [PATCH] Add cuda kernel support for GGUF inference (#11869) * add gguf kernel support Signed-off-by: Isotr0py <2037008807@qq.com> * fix Signed-off-by: Isotr0py <2037008807@qq.com> * optimize Signed-off-by: Isotr0py <2037008807@qq.com> * update * update * update * update * update --------- Signed-off-by: Isotr0py <2037008807@qq.com> Co-authored-by: DN6 --- .github/workflows/nightly_tests.yml | 2 +- docs/source/en/quantization/gguf.md | 10 +++ src/diffusers/quantizers/gguf/utils.py | 95 +++++++++++++++++++++++++- src/diffusers/utils/__init__.py | 1 + src/diffusers/utils/import_utils.py | 5 ++ src/diffusers/utils/testing_utils.py | 13 ++++ tests/quantization/gguf/test_gguf.py | 57 ++++++++++++++++ 7 files changed, 179 insertions(+), 4 deletions(-) diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml index 88a2af87c8..9216564093 100644 --- a/.github/workflows/nightly_tests.yml +++ b/.github/workflows/nightly_tests.yml @@ -333,7 +333,7 @@ jobs: additional_deps: ["peft"] - backend: "gguf" test_location: "gguf" - additional_deps: ["peft"] + additional_deps: ["peft", "kernels"] - backend: "torchao" test_location: "torchao" additional_deps: [] diff --git a/docs/source/en/quantization/gguf.md b/docs/source/en/quantization/gguf.md index aec0875c65..71321d5568 100644 --- a/docs/source/en/quantization/gguf.md +++ b/docs/source/en/quantization/gguf.md @@ -53,6 +53,16 @@ image = pipe(prompt, generator=torch.manual_seed(0)).images[0] image.save("flux-gguf.png") ``` +## Using Optimized CUDA Kernels with GGUF + +Optimized CUDA kernels can accelerate GGUF quantized model inference by approximately 10%. This functionality requires a compatible GPU with `torch.cuda.get_device_capability` greater than 7 and the kernels library: + +```shell +pip install -U kernels +``` + +Once installed, set `DIFFUSERS_GGUF_CUDA_KERNELS=true` to use optimized kernels when available. Note that CUDA kernels may introduce minor numerical differences compared to the original GGUF implementation, potentially causing subtle visual variations in generated images. To disable CUDA kernel usage, set the environment variable `DIFFUSERS_GGUF_CUDA_KERNELS=false`. + ## Supported Quantization Types - BF16 diff --git a/src/diffusers/quantizers/gguf/utils.py b/src/diffusers/quantizers/gguf/utils.py index 41d3517129..3dd00b2ce3 100644 --- a/src/diffusers/quantizers/gguf/utils.py +++ b/src/diffusers/quantizers/gguf/utils.py @@ -12,15 +12,15 @@ # # See the License for the specific language governing permissions and # # limitations under the License. - import inspect +import os from contextlib import nullcontext import gguf import torch import torch.nn as nn -from ...utils import is_accelerate_available +from ...utils import is_accelerate_available, is_kernels_available if is_accelerate_available(): @@ -29,6 +29,82 @@ if is_accelerate_available(): from accelerate.hooks import add_hook_to_module, remove_hook_from_module +can_use_cuda_kernels = ( + os.getenv("DIFFUSERS_GGUF_CUDA_KERNELS", "false").lower() in ["1", "true", "yes"] + and torch.cuda.is_available() + and torch.cuda.get_device_capability()[0] >= 7 +) +if can_use_cuda_kernels and is_kernels_available(): + from kernels import get_kernel + + ops = get_kernel("Isotr0py/ggml") +else: + ops = None + +UNQUANTIZED_TYPES = {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.BF16} +STANDARD_QUANT_TYPES = { + gguf.GGMLQuantizationType.Q4_0, + gguf.GGMLQuantizationType.Q4_1, + gguf.GGMLQuantizationType.Q5_0, + gguf.GGMLQuantizationType.Q5_1, + gguf.GGMLQuantizationType.Q8_0, + gguf.GGMLQuantizationType.Q8_1, +} +KQUANT_TYPES = { + gguf.GGMLQuantizationType.Q2_K, + gguf.GGMLQuantizationType.Q3_K, + gguf.GGMLQuantizationType.Q4_K, + gguf.GGMLQuantizationType.Q5_K, + gguf.GGMLQuantizationType.Q6_K, +} +IMATRIX_QUANT_TYPES = { + gguf.GGMLQuantizationType.IQ1_M, + gguf.GGMLQuantizationType.IQ1_S, + gguf.GGMLQuantizationType.IQ2_XXS, + gguf.GGMLQuantizationType.IQ2_XS, + gguf.GGMLQuantizationType.IQ2_S, + gguf.GGMLQuantizationType.IQ3_XXS, + gguf.GGMLQuantizationType.IQ3_S, + gguf.GGMLQuantizationType.IQ4_XS, + gguf.GGMLQuantizationType.IQ4_NL, +} +# TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization. +# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add +# MMQ kernel for I-Matrix quantization. +DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES +MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES +MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES + + +def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor: + # there is no need to call any kernel for fp16/bf16 + if qweight_type in UNQUANTIZED_TYPES: + return x @ qweight.T + + # TODO(Isotr0py): GGUF's MMQ and MMVQ implementation are designed for + # contiguous batching and inefficient with diffusers' batching, + # so we disabled it now. + + # elif qweight_type in MMVQ_QUANT_TYPES: + # y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0]) + # elif qweight_type in MMQ_QUANT_TYPES: + # y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0]) + + # If there is no available MMQ kernel, fallback to dequantize + if qweight_type in DEQUANT_TYPES: + block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type] + shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size) + weight = ops.ggml_dequantize(qweight, qweight_type, *shape) + y = x @ weight.to(x.dtype).T + else: + # Raise an error if the quantization type is not supported. + # Might be useful if llama.cpp adds a new quantization type. + # Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type. + qweight_type = gguf.GGMLQuantizationType(qweight_type) + raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}") + return y.as_tensor() + + # Copied from diffusers.quantizers.bitsandbytes.utils._create_accelerate_new_hook def _create_accelerate_new_hook(old_hook): r""" @@ -451,11 +527,24 @@ class GGUFLinear(nn.Linear): ) -> None: super().__init__(in_features, out_features, bias, device) self.compute_dtype = compute_dtype + self.device = device - def forward(self, inputs): + def forward(self, inputs: torch.Tensor): + if ops is not None and self.weight.is_cuda and inputs.is_cuda: + return self.forward_cuda(inputs) + return self.forward_native(inputs) + + def forward_native(self, inputs: torch.Tensor): weight = dequantize_gguf_tensor(self.weight) weight = weight.to(self.compute_dtype) bias = self.bias.to(self.compute_dtype) if self.bias is not None else None output = torch.nn.functional.linear(inputs, weight, bias) return output + + def forward_cuda(self, inputs: torch.Tensor): + quant_type = self.weight.quant_type + output = _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type) + if self.bias is not None: + output += self.bias.to(self.compute_dtype) + return output diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index cadcedb98a..75a2bdd13e 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -81,6 +81,7 @@ from .import_utils import ( is_invisible_watermark_available, is_k_diffusion_available, is_k_diffusion_version, + is_kernels_available, is_librosa_available, is_matplotlib_available, is_nltk_available, diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index a27c2da648..d8b26bda46 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -192,6 +192,7 @@ _torch_xla_available, _torch_xla_version = _is_package_available("torch_xla") _torch_npu_available, _torch_npu_version = _is_package_available("torch_npu") _transformers_available, _transformers_version = _is_package_available("transformers") _hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub") +_kernels_available, _kernels_version = _is_package_available("kernels") _inflect_available, _inflect_version = _is_package_available("inflect") _unidecode_available, _unidecode_version = _is_package_available("unidecode") _k_diffusion_available, _k_diffusion_version = _is_package_available("k_diffusion") @@ -277,6 +278,10 @@ def is_accelerate_available(): return _accelerate_available +def is_kernels_available(): + return _kernels_available + + def is_k_diffusion_available(): return _k_diffusion_available diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 3d9444975d..a0307c108a 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -36,6 +36,7 @@ from .import_utils import ( is_compel_available, is_flax_available, is_gguf_available, + is_kernels_available, is_note_seq_available, is_onnx_available, is_opencv_available, @@ -634,6 +635,18 @@ def require_torchao_version_greater_or_equal(torchao_version): return decorator +def require_kernels_version_greater_or_equal(kernels_version): + def decorator(test_case): + correct_kernels_version = is_kernels_available() and version.parse( + version.parse(importlib.metadata.version("kernels")).base_version + ) >= version.parse(kernels_version) + return unittest.skipUnless( + correct_kernels_version, f"Test requires kernels with version greater than {kernels_version}." + )(test_case) + + return decorator + + def deprecate_after_peft_backend(test_case): """ Decorator marking a test that will be skipped after PEFT backend diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py index ba41678eaa..e9d7034f03 100644 --- a/tests/quantization/gguf/test_gguf.py +++ b/tests/quantization/gguf/test_gguf.py @@ -30,8 +30,10 @@ from diffusers.utils.testing_utils import ( nightly, numpy_cosine_similarity_distance, require_accelerate, + require_accelerator, require_big_accelerator, require_gguf_version_greater_or_equal, + require_kernels_version_greater_or_equal, require_peft_backend, require_torch_version_greater, torch_device, @@ -41,11 +43,66 @@ from ..test_torch_compile_utils import QuantCompileTests if is_gguf_available(): + import gguf + from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter enable_full_determinism() +@nightly +@require_accelerate +@require_accelerator +@require_gguf_version_greater_or_equal("0.10.0") +@require_kernels_version_greater_or_equal("0.9.0") +class GGUFCudaKernelsTests(unittest.TestCase): + def setUp(self): + gc.collect() + backend_empty_cache(torch_device) + + def tearDown(self): + gc.collect() + backend_empty_cache(torch_device) + + def test_cuda_kernels_vs_native(self): + if torch_device != "cuda": + self.skipTest("CUDA kernels test requires CUDA device") + + from diffusers.quantizers.gguf.utils import GGUFLinear, can_use_cuda_kernels + + if not can_use_cuda_kernels: + self.skipTest("CUDA kernels not available (compute capability < 7 or kernels not installed)") + + test_quant_types = ["Q4_0", "Q4_K"] + test_shape = (1, 64, 512) # batch, seq_len, hidden_dim + compute_dtype = torch.bfloat16 + + for quant_type in test_quant_types: + qtype = getattr(gguf.GGMLQuantizationType, quant_type) + in_features, out_features = 512, 512 + + torch.manual_seed(42) + float_weight = torch.randn(out_features, in_features, dtype=torch.float32) + quantized_data = gguf.quants.quantize(float_weight.numpy(), qtype) + weight_data = torch.from_numpy(quantized_data).to(device=torch_device) + weight = GGUFParameter(weight_data, quant_type=qtype) + + x = torch.randn(test_shape, dtype=compute_dtype, device=torch_device) + + linear = GGUFLinear(in_features, out_features, bias=True, compute_dtype=compute_dtype) + linear.weight = weight + linear.bias = nn.Parameter(torch.randn(out_features, dtype=compute_dtype)) + linear = linear.to(torch_device) + + with torch.no_grad(): + output_native = linear.forward_native(x) + output_cuda = linear.forward_cuda(x) + + assert torch.allclose(output_native, output_cuda, 1e-2), ( + f"GGUF CUDA Kernel Output is different from Native Output for {quant_type}" + ) + + @nightly @require_big_accelerator @require_accelerate