diff --git a/src/diffusers/quantizers/gguf/utils.py b/src/diffusers/quantizers/gguf/utils.py
index edbc60abf5..aa6a2818d1 100644
--- a/src/diffusers/quantizers/gguf/utils.py
+++ b/src/diffusers/quantizers/gguf/utils.py
@@ -91,18 +91,18 @@ def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: in
 
     # y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
     # If there is no available MMQ kernel, fallback to dequantize
-    elif qweight_type in DEQUANT_TYPES:
+    if qweight_type in DEQUANT_TYPES:
         block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
         shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
         weight = ops.ggml_dequantize(qweight, qweight_type, *shape)
-        y = x @ weight.T
+        y = x @ weight.to(x.dtype).T
     else:
         # Raise an error if the quantization type is not supported.
         # Might be useful if llama.cpp adds a new quantization type.
         # Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
         qweight_type = gguf.GGMLQuantizationType(qweight_type)
         raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")
-    return y
+    return y.as_tensor()
 
 
 # Copied from diffusers.quantizers.bitsandbytes.utils._create_accelerate_new_hook
diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py
index aa558b3e82..a03efdd2be 100644
--- a/tests/quantization/gguf/test_gguf.py
+++ b/tests/quantization/gguf/test_gguf.py
@@ -73,15 +73,12 @@ class GGUFCudaKernelsTests(unittest.TestCase):
 
         for quant_type in test_quant_types:
             qtype = getattr(gguf.GGMLQuantizationType, quant_type)
-            block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
-
             in_features, out_features = 512, 512
-            total_elements = in_features * out_features
-            n_blocks = total_elements // block_size
-            weight_bytes = n_blocks * type_size
 
             torch.manual_seed(42)
-            weight_data = torch.randint(0, 256, (weight_bytes,), dtype=torch.uint8, device=torch_device)
+            float_weight = torch.randn(out_features, in_features, dtype=torch.float32)
+            quantized_data = gguf.quants.quantize(float_weight.numpy(), qtype)
+            weight_data = torch.from_numpy(quantized_data).to(device=torch_device)
             weight = GGUFParameter(weight_data, quant_type=qtype)
 
             x = torch.randn(test_shape, dtype=compute_dtype, device=torch_device)
@@ -95,9 +92,9 @@ class GGUFCudaKernelsTests(unittest.TestCase):
             output_native = linear.forward_native(x)
             output_cuda = linear.forward_cuda(x)
 
-            # Compare outputs
-            max_diff = torch.abs(output_cuda - output_native).max()
-            assert max_diff < 1e-4, "GGUF CUDA Kernel Output is different from Native Output"
+            assert torch.allclose(output_native, output_cuda, 1e-2), (
+                f"GGUF CUDA Kernel Output is different from Native Output for {quant_type}"
+            )
 
 
 @nightly
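
A minimal, self-contained sketch of what the changed fallback path computes, using the CPU reference
quantize/dequantize helpers from gguf-py in place of the fused ops.ggml_dequantize CUDA kernel. The
Q4_0 choice, the 512x512 shape, and the variable names below are illustrative assumptions, not part
of this patch:

    # Illustrative only: emulate the DEQUANT_TYPES fallback of _fused_mul_mat_gguf
    # on CPU with gguf-py's reference kernels (the patch uses ops.ggml_dequantize on GPU).
    import gguf
    import torch

    qtype = gguf.GGMLQuantizationType.Q4_0               # assumed to be in DEQUANT_TYPES
    block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]

    torch.manual_seed(42)
    float_weight = torch.randn(512, 512, dtype=torch.float32)  # (out_features, in_features)
    # Pack the float weight into GGUF Q4_0 blocks (uint8), as the updated test now does.
    qweight = torch.from_numpy(gguf.quants.quantize(float_weight.numpy(), qtype))

    # Recover a float32 weight and reproduce the changed line: y = x @ weight.to(x.dtype).T
    shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
    weight = torch.from_numpy(gguf.quants.dequantize(qweight.numpy(), qtype)).reshape(shape)

    x = torch.randn(2, 512, dtype=torch.bfloat16)
    y = x @ weight.to(x.dtype).T                          # matmul runs in the activation dtype
    assert y.shape == (2, 512)
    assert torch.allclose(weight, float_weight, atol=1.0) # 4-bit roundtrip is lossy, so only a coarse check

The cast in weight.to(x.dtype) matters because the dequantized weight comes back as float32 while the
activations are typically half precision, and the test's switch from random bytes to a real
gguf.quants.quantize roundtrip is what makes the native/CUDA comparison meaningful.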