commit 66bd237bc5 (parent 6c4d01def7)
Author: Isotr0py
Date:   2025-07-06 01:00:01 +08:00

Signed-off-by: Isotr0py <2037008807@qq.com>


@@ -78,17 +78,21 @@ def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor,
     # there is no need to call any kernel for fp16/bf16
     if qweight_type in UNQUANTIZED_TYPES:
         return x @ qweight.T
-    # enable MMVQ in contiguous batching with batch_size=1
-    if qweight_type in MMVQ_QUANT_TYPES:
-        y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
-    # Use MMQ kernel if it's available (standard + k-quants)
-    elif qweight_type in MMQ_QUANT_TYPES:
-        y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
+
+    # TODO(Isotr0py): GGUF's MMQ and MMVQ implementations are designed for
+    # contiguous batching and are inefficient with diffusers' batching,
+    # so they are disabled for now.
+
+    # elif qweight_type in MMVQ_QUANT_TYPES:
+    #     y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
+    # elif qweight_type in MMQ_QUANT_TYPES:
+    #     y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
+
     # If there is no available MMQ kernel, fall back to dequantize
     elif qweight_type in DEQUANT_TYPES:
         block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
         shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
-        weight = ops.ggml_dequantize(qweight, qweight_type, *shape, x.dtype)
+        weight = ops.ggml_dequantize(qweight, qweight_type, *shape)
         y = x @ weight.T
     else:
         # Raise an error if the quantization type is not supported.
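
Note on the dequantize fallback above: GGUF stores each quantized type in fixed-size blocks, packing block_size weights into type_size bytes, so the logical column count of the matrix is recovered as qweight.shape[1] // type_size * block_size. A minimal sketch of that shape arithmetic using the gguf package follows; the Q4_0 type and the (4096, 2304)-byte tensor are illustrative choices, not values taken from this commit.

import gguf

# Q4_0 packs 32 weights into an 18-byte block
# (a 2-byte fp16 scale + 16 bytes of 4-bit quants).
block_size, type_size = gguf.GGML_QUANT_SIZES[gguf.GGMLQuantizationType.Q4_0]
assert (block_size, type_size) == (32, 18)

# A packed qweight of shape (4096, 2304) bytes therefore
# dequantizes to a (4096, 4096) float matrix.
rows, packed_cols = 4096, 2304
logical_cols = packed_cols // type_size * block_size
assert logical_cols == 4096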
@@ -539,5 +543,10 @@ class GGUFLinear(nn.Linear):
     def forward_cuda(self, inputs):
         quant_type = self.weight.quant_type
-        return _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type)
+        orig_shape = inputs.shape
+        inputs = inputs.view(-1, orig_shape[-1])
+        output = _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type)
+        if self.bias is not None:
+            output = output + self.bias.to(self.compute_dtype)
+        return output.view(*orig_shape[:-1], -1)
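
Note on the forward_cuda change: the fused GGUF matmul operates on 2-D activations, while diffusers' linear layers typically receive (batch, seq_len, hidden) inputs, and the fused path also skips the bias that nn.Linear.forward would apply. A pure-PyTorch sketch of the equivalent behavior follows; reference_forward, its dense weight, and compute_dtype are illustrative stand-ins for the kernel path, not the library API.

import torch
from typing import Optional

def reference_forward(inputs: torch.Tensor, weight: torch.Tensor,
                      bias: Optional[torch.Tensor],
                      compute_dtype: torch.dtype) -> torch.Tensor:
    # Flatten all leading dims: the fused GGUF matmul expects a 2-D activation.
    orig_shape = inputs.shape
    x = inputs.view(-1, orig_shape[-1]).to(compute_dtype)
    y = x @ weight.to(compute_dtype).T  # stand-in for _fused_mul_mat_gguf
    if bias is not None:
        y = y + bias.to(compute_dtype)
    # Restore the leading dims; the output feature dim is inferred.
    return y.view(*orig_shape[:-1], -1)

# e.g. a (2, 77, 4096) activation against a (4096, 4096) weight:
out = reference_forward(torch.randn(2, 77, 4096), torch.randn(4096, 4096),
                        torch.randn(4096), torch.float32)
assert out.shape == (2, 77, 4096)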