1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
Dhruv Nair
2025-07-24 06:30:12 +02:00
parent de1fb4b615
commit db94e2b5a7
2 changed files with 9 additions and 12 deletions

View File

@@ -91,18 +91,18 @@ def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: in
# y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
# If there is no available MMQ kernel, fallback to dequantize
elif qweight_type in DEQUANT_TYPES:
if qweight_type in DEQUANT_TYPES:
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
weight = ops.ggml_dequantize(qweight, qweight_type, *shape)
y = x @ weight.T
y = x @ weight.to(x.dtype).T
else:
# Raise an error if the quantization type is not supported.
# Might be useful if llama.cpp adds a new quantization type.
# Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
qweight_type = gguf.GGMLQuantizationType(qweight_type)
raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")
return y
return y.as_tensor()
# Copied from diffusers.quantizers.bitsandbytes.utils._create_accelerate_new_hook

View File

@@ -73,15 +73,12 @@ class GGUFCudaKernelsTests(unittest.TestCase):
for quant_type in test_quant_types:
qtype = getattr(gguf.GGMLQuantizationType, quant_type)
block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
in_features, out_features = 512, 512
total_elements = in_features * out_features
n_blocks = total_elements // block_size
weight_bytes = n_blocks * type_size
torch.manual_seed(42)
weight_data = torch.randint(0, 256, (weight_bytes,), dtype=torch.uint8, device=torch_device)
float_weight = torch.randn(out_features, in_features, dtype=torch.float32)
quantized_data = gguf.quants.quantize(float_weight.numpy(), qtype)
weight_data = torch.from_numpy(quantized_data).to(device=torch_device)
weight = GGUFParameter(weight_data, quant_type=qtype)
x = torch.randn(test_shape, dtype=compute_dtype, device=torch_device)
@@ -95,9 +92,9 @@ class GGUFCudaKernelsTests(unittest.TestCase):
output_native = linear.forward_native(x)
output_cuda = linear.forward_cuda(x)
# Compare outputs
max_diff = torch.abs(output_cuda - output_native).max()
assert max_diff < 1e-4, "GGUF CUDA Kernel Output is different from Native Output"
assert torch.allclose(output_native, output_cuda, 1e-2), (
f"GGUF CUDA Kernel Output is different from Native Output for {quant_type}"
)
@nightly