1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

add gguf kernel support

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py
2025-07-05 17:47:06 +08:00
parent 425a715e35
commit 6c4d01def7
3 changed files with 89 additions and 1 deletions

View File

@@ -17,10 +17,11 @@ import inspect
from contextlib import nullcontext
import gguf
from gguf import GGMLQuantizationType as WeightType
import torch
import torch.nn as nn
from ...utils import is_accelerate_available
from ...utils import is_accelerate_available, is_kernels_available
if is_accelerate_available():
@@ -29,6 +30,76 @@ if is_accelerate_available():
from accelerate.hooks import add_hook_to_module, remove_hook_from_module
can_use_cuda_kernels = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 7
if can_use_cuda_kernels and is_kernels_available():
from kernels import get_kernel
ops = get_kernel("Isotr0py/ggml")
else:
ops = None
UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16}
STANDARD_QUANT_TYPES = {
WeightType.Q4_0,
WeightType.Q4_1,
WeightType.Q5_0,
WeightType.Q5_1,
WeightType.Q8_0,
WeightType.Q8_1,
}
KQUANT_TYPES = {
WeightType.Q2_K,
WeightType.Q3_K,
WeightType.Q4_K,
WeightType.Q5_K,
WeightType.Q6_K,
}
IMATRIX_QUANT_TYPES = {
WeightType.IQ1_M,
WeightType.IQ1_S,
WeightType.IQ2_XXS,
WeightType.IQ2_XS,
WeightType.IQ2_S,
WeightType.IQ3_XXS,
WeightType.IQ3_S,
WeightType.IQ4_XS,
WeightType.IQ4_NL,
}
# TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
# MMQ kernel for I-Matrix quantization.
DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor,
qweight_type: int) -> torch.Tensor:
# there is no need to call any kernel for fp16/bf16
if qweight_type in UNQUANTIZED_TYPES:
return x @ qweight.T
# enable MMVQ in contiguous batching with batch_size=1
if qweight_type in MMVQ_QUANT_TYPES:
y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
# Use MMQ Kernel if it's available (standard + k-quants)
elif qweight_type in MMQ_QUANT_TYPES:
y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
# If there is no available MMQ kernel, fallback to dequantize
elif qweight_type in DEQUANT_TYPES:
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
weight = ops.ggml_dequantize(qweight, qweight_type, *shape, x.dtype)
y = x @ weight.T
else:
# Raise an error if the quantization type is not supported.
# Might be useful if llama.cpp adds a new quantization type.
# Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
qweight_type = WeightType(qweight_type)
raise NotImplementedError(
f"Unsupported GGUF quantization type: {qweight_type}")
return y
# Copied from diffusers.quantizers.bitsandbytes.utils._create_accelerate_new_hook
def _create_accelerate_new_hook(old_hook):
r"""
@@ -451,11 +522,22 @@ class GGUFLinear(nn.Linear):
) -> None:
super().__init__(in_features, out_features, bias, device)
self.compute_dtype = compute_dtype
self.device = device
def forward(self, inputs):
if ops is not None and self.weight.is_cuda and inputs.is_cuda:
return self.forward_cuda(inputs)
return self.forward_native(inputs)
def forward_native(self, inputs):
weight = dequantize_gguf_tensor(self.weight)
weight = weight.to(self.compute_dtype)
bias = self.bias.to(self.compute_dtype) if self.bias is not None else None
output = torch.nn.functional.linear(inputs, weight, bias)
return output
def forward_cuda(self, inputs):
quant_type = self.weight.quant_type
return _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type)

View File

@@ -76,6 +76,7 @@ from .import_utils import (
is_hpu_available,
is_inflect_available,
is_invisible_watermark_available,
is_kernels_available,
is_k_diffusion_available,
is_k_diffusion_version,
is_librosa_available,

View File

@@ -192,6 +192,7 @@ _torch_xla_available, _torch_xla_version = _is_package_available("torch_xla")
_torch_npu_available, _torch_npu_version = _is_package_available("torch_npu")
_transformers_available, _transformers_version = _is_package_available("transformers")
_hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub")
_kernels_available, _kernels_version = _is_package_available("kernels")
_inflect_available, _inflect_version = _is_package_available("inflect")
_unidecode_available, _unidecode_version = _is_package_available("unidecode")
_k_diffusion_available, _k_diffusion_version = _is_package_available("k_diffusion")
@@ -274,6 +275,10 @@ def is_accelerate_available():
return _accelerate_available
def is_kernels_available():
return _kernels_available
def is_k_diffusion_available():
return _k_diffusion_available