1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py
2025-07-06 01:47:13 +08:00
parent 66bd237bc5
commit e46571a7aa
2 changed files with 32 additions and 38 deletions

View File

@@ -17,7 +17,6 @@ import inspect
from contextlib import nullcontext
import gguf
from gguf import GGMLQuantizationType as WeightType
import torch
import torch.nn as nn
@@ -33,37 +32,37 @@ if is_accelerate_available():
can_use_cuda_kernels = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 7
if can_use_cuda_kernels and is_kernels_available():
from kernels import get_kernel
ops = get_kernel("Isotr0py/ggml")
else:
ops = None
UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16}
UNQUANTIZED_TYPES = {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.BF16}
STANDARD_QUANT_TYPES = {
WeightType.Q4_0,
WeightType.Q4_1,
WeightType.Q5_0,
WeightType.Q5_1,
WeightType.Q8_0,
WeightType.Q8_1,
gguf.GGMLQuantizationType.Q4_0,
gguf.GGMLQuantizationType.Q4_1,
gguf.GGMLQuantizationType.Q5_0,
gguf.GGMLQuantizationType.Q5_1,
gguf.GGMLQuantizationType.Q8_0,
gguf.GGMLQuantizationType.Q8_1,
}
KQUANT_TYPES = {
WeightType.Q2_K,
WeightType.Q3_K,
WeightType.Q4_K,
WeightType.Q5_K,
WeightType.Q6_K,
gguf.GGMLQuantizationType.Q2_K,
gguf.GGMLQuantizationType.Q3_K,
gguf.GGMLQuantizationType.Q4_K,
gguf.GGMLQuantizationType.Q5_K,
gguf.GGMLQuantizationType.Q6_K,
}
IMATRIX_QUANT_TYPES = {
WeightType.IQ1_M,
WeightType.IQ1_S,
WeightType.IQ2_XXS,
WeightType.IQ2_XS,
WeightType.IQ2_S,
WeightType.IQ3_XXS,
WeightType.IQ3_S,
WeightType.IQ4_XS,
WeightType.IQ4_NL,
gguf.GGMLQuantizationType.IQ1_M,
gguf.GGMLQuantizationType.IQ1_S,
gguf.GGMLQuantizationType.IQ2_XXS,
gguf.GGMLQuantizationType.IQ2_XS,
gguf.GGMLQuantizationType.IQ2_S,
gguf.GGMLQuantizationType.IQ3_XXS,
gguf.GGMLQuantizationType.IQ3_S,
gguf.GGMLQuantizationType.IQ4_XS,
gguf.GGMLQuantizationType.IQ4_NL,
}
# TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
@@ -73,8 +72,7 @@ MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor,
qweight_type: int) -> torch.Tensor:
def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor:
# there is no need to call any kernel for fp16/bf16
if qweight_type in UNQUANTIZED_TYPES:
return x @ qweight.T
@@ -87,8 +85,8 @@ def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor,
# y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
# elif qweight_type in MMQ_QUANT_TYPES:
# y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
# If there is no available MMQ kernel, fallback to dequantize
# If there is no available MMQ kernel, fallback to dequantize
elif qweight_type in DEQUANT_TYPES:
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
@@ -98,9 +96,8 @@ def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor,
# Raise an error if the quantization type is not supported.
# Might be useful if llama.cpp adds a new quantization type.
# Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
qweight_type = WeightType(qweight_type)
raise NotImplementedError(
f"Unsupported GGUF quantization type: {qweight_type}")
qweight_type = gguf.GGMLQuantizationType(qweight_type)
raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")
return y
@@ -528,12 +525,12 @@ class GGUFLinear(nn.Linear):
self.compute_dtype = compute_dtype
self.device = device
def forward(self, inputs):
def forward(self, inputs: torch.Tensor):
if ops is not None and self.weight.is_cuda and inputs.is_cuda:
return self.forward_cuda(inputs)
return self.forward_native(inputs)
def forward_native(self, inputs):
def forward_native(self, inputs: torch.Tensor):
weight = dequantize_gguf_tensor(self.weight)
weight = weight.to(self.compute_dtype)
bias = self.bias.to(self.compute_dtype) if self.bias is not None else None
@@ -541,12 +538,9 @@ class GGUFLinear(nn.Linear):
output = torch.nn.functional.linear(inputs, weight, bias)
return output
def forward_cuda(self, inputs):
def forward_cuda(self, inputs: torch.Tensor):
quant_type = self.weight.quant_type
orig_shape = inputs.shape
inputs = inputs.view(-1, orig_shape[-1])
output = _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type)
if self.bias is not None:
output = output + self.bias.to(self.compute_dtype)
return output.view(*orig_shape[:-1], -1)
output += self.bias.to(self.compute_dtype)
return output

View File

@@ -76,9 +76,9 @@ from .import_utils import (
is_hpu_available,
is_inflect_available,
is_invisible_watermark_available,
is_kernels_available,
is_k_diffusion_available,
is_k_diffusion_version,
is_kernels_available,
is_librosa_available,
is_matplotlib_available,
is_nltk_available,