mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
update
This commit is contained in:
@@ -91,18 +91,18 @@ def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: in
|
||||
# y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
|
||||
|
||||
# If there is no available MMQ kernel, fallback to dequantize
|
||||
elif qweight_type in DEQUANT_TYPES:
|
||||
if qweight_type in DEQUANT_TYPES:
|
||||
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
|
||||
shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
|
||||
weight = ops.ggml_dequantize(qweight, qweight_type, *shape)
|
||||
y = x @ weight.T
|
||||
y = x @ weight.to(x.dtype).T
|
||||
else:
|
||||
# Raise an error if the quantization type is not supported.
|
||||
# Might be useful if llama.cpp adds a new quantization type.
|
||||
# Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
|
||||
qweight_type = gguf.GGMLQuantizationType(qweight_type)
|
||||
raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")
|
||||
return y
|
||||
return y.as_tensor()
|
||||
|
||||
|
||||
# Copied from diffusers.quantizers.bitsandbytes.utils._create_accelerate_new_hook
|
||||
|
||||
@@ -73,15 +73,12 @@ class GGUFCudaKernelsTests(unittest.TestCase):
|
||||
|
||||
for quant_type in test_quant_types:
|
||||
qtype = getattr(gguf.GGMLQuantizationType, quant_type)
|
||||
block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
|
||||
|
||||
in_features, out_features = 512, 512
|
||||
total_elements = in_features * out_features
|
||||
n_blocks = total_elements // block_size
|
||||
weight_bytes = n_blocks * type_size
|
||||
|
||||
torch.manual_seed(42)
|
||||
weight_data = torch.randint(0, 256, (weight_bytes,), dtype=torch.uint8, device=torch_device)
|
||||
float_weight = torch.randn(out_features, in_features, dtype=torch.float32)
|
||||
quantized_data = gguf.quants.quantize(float_weight.numpy(), qtype)
|
||||
weight_data = torch.from_numpy(quantized_data).to(device=torch_device)
|
||||
weight = GGUFParameter(weight_data, quant_type=qtype)
|
||||
|
||||
x = torch.randn(test_shape, dtype=compute_dtype, device=torch_device)
|
||||
@@ -95,9 +92,9 @@ class GGUFCudaKernelsTests(unittest.TestCase):
|
||||
output_native = linear.forward_native(x)
|
||||
output_cuda = linear.forward_cuda(x)
|
||||
|
||||
# Compare outputs
|
||||
max_diff = torch.abs(output_cuda - output_native).max()
|
||||
assert max_diff < 1e-4, "GGUF CUDA Kernel Output is different from Native Output"
|
||||
assert torch.allclose(output_native, output_cuda, 1e-2), (
|
||||
f"GGUF CUDA Kernel Output is different from Native Output for {quant_type}"
|
||||
)
|
||||
|
||||
|
||||
@nightly
|
||||
|
||||
Reference in New Issue
Block a user