@@ -78,17 +78,21 @@ def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor,
     # there is no need to call any kernel for fp16/bf16
     if qweight_type in UNQUANTIZED_TYPES:
         return x @ qweight.T
-    # enable MMVQ in contiguous batching with batch_size=1
-    if qweight_type in MMVQ_QUANT_TYPES:
-        y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
-    # Use MMQ Kernel if it's available (standard + k-quants)
-    elif qweight_type in MMQ_QUANT_TYPES:
-        y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
+
+    # TODO(Isotr0py): GGUF's MMQ and MMVQ implementation are designed for
+    # contiguous batching and inefficient with diffusers' batching,
+    # so we disabled it now.
+
+    # elif qweight_type in MMVQ_QUANT_TYPES:
+    #     y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
+    # elif qweight_type in MMQ_QUANT_TYPES:
+    #     y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
     # If there is no available MMQ kernel, fallback to dequantize
+
     elif qweight_type in DEQUANT_TYPES:
         block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
         shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
-        weight = ops.ggml_dequantize(qweight, qweight_type, *shape, x.dtype)
+        weight = ops.ggml_dequantize(qweight, qweight_type, *shape)
         y = x @ weight.T
     else:
         # Raise an error if the quantization type is not supported.
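
Note on the dequantize fallback kept above: the packed GGUF tensor stores, per output row, in_features // block_size quantized blocks of type_size bytes each, so the logical second dimension is rebuilt as qweight.shape[1] // type_size * block_size. Below is a minimal sketch of that arithmetic, assuming only the gguf package; the Q4_K choice and the packed sizes are illustrative, not taken from this PR.

import gguf

# block_size = elements per quantized block, type_size = bytes per block
qtype = gguf.GGMLQuantizationType.Q4_K
block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]

# hypothetical packed weight: 3072 output rows, each holding 2 quantized blocks
out_features = 3072
packed_bytes = 2 * type_size

# same arithmetic as the fallback branch:
#   shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
in_features = packed_bytes // type_size * block_size
print((out_features, in_features))  # second dim is 2 * block_size elements after dequantize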
@@ -539,5 +543,10 @@ class GGUFLinear(nn.Linear):
 
     def forward_cuda(self, inputs):
         quant_type = self.weight.quant_type
-        return _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type)
+        orig_shape = inputs.shape
+        inputs = inputs.view(-1, orig_shape[-1])
+        output = _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type)
+        if self.bias is not None:
+            output = output + self.bias.to(self.compute_dtype)
+        return output.view(*orig_shape[:-1], -1)
 
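
The reworked forward_cuda handles the fact that diffusers' linear layers typically receive 3-D activations of shape (batch, sequence, hidden), while the fused GGUF matmul path works on 2-D inputs: the input is flattened first, the leading dimensions are restored afterwards, and the bias is added explicitly to preserve nn.Linear semantics. A minimal shape-handling sketch follows, with a plain matmul standing in for _fused_mul_mat_gguf; the fake_fused_mul_mat name and the tensor sizes are made up for illustration.

import torch


def fake_fused_mul_mat(x2d: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # stand-in for _fused_mul_mat_gguf: assumes a 2-D activation
    assert x2d.dim() == 2
    return x2d @ weight.T


batch, seq_len, hidden, out_features = 2, 16, 64, 128
inputs = torch.randn(batch, seq_len, hidden)
weight = torch.randn(out_features, hidden)
bias = torch.randn(out_features)

# same flatten -> matmul -> bias -> restore pattern as forward_cuda above
orig_shape = inputs.shape
output = fake_fused_mul_mat(inputs.view(-1, orig_shape[-1]), weight)
output = output + bias
output = output.view(*orig_shape[:-1], -1)
print(output.shape)  # torch.Size([2, 16, 128])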