diff --git a/src/diffusers/quantizers/gguf/utils.py b/src/diffusers/quantizers/gguf/utils.py index 2701040854..03521eadb2 100644 --- a/src/diffusers/quantizers/gguf/utils.py +++ b/src/diffusers/quantizers/gguf/utils.py @@ -78,17 +78,21 @@ def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, # there is no need to call any kernel for fp16/bf16 if qweight_type in UNQUANTIZED_TYPES: return x @ qweight.T - # enable MMVQ in contiguous batching with batch_size=1 - if qweight_type in MMVQ_QUANT_TYPES: - y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0]) - # Use MMQ Kernel if it's available (standard + k-quants) - elif qweight_type in MMQ_QUANT_TYPES: - y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0]) + + # TODO(Isotr0py): GGUF's MMQ and MMVQ implementation are designed for + # contiguous batching and inefficient with diffusers' batching, + # so we disabled it now. + + # elif qweight_type in MMVQ_QUANT_TYPES: + # y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0]) + # elif qweight_type in MMQ_QUANT_TYPES: + # y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0]) # If there is no available MMQ kernel, fallback to dequantize + elif qweight_type in DEQUANT_TYPES: block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type] shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size) - weight = ops.ggml_dequantize(qweight, qweight_type, *shape, x.dtype) + weight = ops.ggml_dequantize(qweight, qweight_type, *shape) y = x @ weight.T else: # Raise an error if the quantization type is not supported. @@ -539,5 +543,10 @@ class GGUFLinear(nn.Linear): def forward_cuda(self, inputs): quant_type = self.weight.quant_type - return _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type) + orig_shape = inputs.shape + inputs = inputs.view(-1, orig_shape[-1]) + output = _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type) + if self.bias is not None: + output = output + self.bias.to(self.compute_dtype) + return output.view(*orig_shape[:-1], -1)