From e46571a7aad95b2a4efc10d076740ad260e129fc Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Sun, 6 Jul 2025 01:47:13 +0800
Subject: [PATCH] optimize

Signed-off-by: Isotr0py <2037008807@qq.com>
---
 src/diffusers/quantizers/gguf/utils.py | 68 ++++++++++++--------------
 src/diffusers/utils/__init__.py        |  2 +-
 2 files changed, 32 insertions(+), 38 deletions(-)

diff --git a/src/diffusers/quantizers/gguf/utils.py b/src/diffusers/quantizers/gguf/utils.py
index 03521eadb2..31f6ec3e73 100644
--- a/src/diffusers/quantizers/gguf/utils.py
+++ b/src/diffusers/quantizers/gguf/utils.py
@@ -17,7 +17,6 @@ import inspect
 from contextlib import nullcontext
 
 import gguf
-from gguf import GGMLQuantizationType as WeightType
 import torch
 import torch.nn as nn
 
@@ -33,37 +32,37 @@ if is_accelerate_available():
 can_use_cuda_kernels = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 7
 if can_use_cuda_kernels and is_kernels_available():
     from kernels import get_kernel
+
     ops = get_kernel("Isotr0py/ggml")
 else:
     ops = None
 
-
-UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16}
+UNQUANTIZED_TYPES = {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.BF16}
 STANDARD_QUANT_TYPES = {
-    WeightType.Q4_0,
-    WeightType.Q4_1,
-    WeightType.Q5_0,
-    WeightType.Q5_1,
-    WeightType.Q8_0,
-    WeightType.Q8_1,
+    gguf.GGMLQuantizationType.Q4_0,
+    gguf.GGMLQuantizationType.Q4_1,
+    gguf.GGMLQuantizationType.Q5_0,
+    gguf.GGMLQuantizationType.Q5_1,
+    gguf.GGMLQuantizationType.Q8_0,
+    gguf.GGMLQuantizationType.Q8_1,
 }
 KQUANT_TYPES = {
-    WeightType.Q2_K,
-    WeightType.Q3_K,
-    WeightType.Q4_K,
-    WeightType.Q5_K,
-    WeightType.Q6_K,
+    gguf.GGMLQuantizationType.Q2_K,
+    gguf.GGMLQuantizationType.Q3_K,
+    gguf.GGMLQuantizationType.Q4_K,
+    gguf.GGMLQuantizationType.Q5_K,
+    gguf.GGMLQuantizationType.Q6_K,
 }
 IMATRIX_QUANT_TYPES = {
-    WeightType.IQ1_M,
-    WeightType.IQ1_S,
-    WeightType.IQ2_XXS,
-    WeightType.IQ2_XS,
-    WeightType.IQ2_S,
-    WeightType.IQ3_XXS,
-    WeightType.IQ3_S,
-    WeightType.IQ4_XS,
-    WeightType.IQ4_NL,
+    gguf.GGMLQuantizationType.IQ1_M,
+    gguf.GGMLQuantizationType.IQ1_S,
+    gguf.GGMLQuantizationType.IQ2_XXS,
+    gguf.GGMLQuantizationType.IQ2_XS,
+    gguf.GGMLQuantizationType.IQ2_S,
+    gguf.GGMLQuantizationType.IQ3_XXS,
+    gguf.GGMLQuantizationType.IQ3_S,
+    gguf.GGMLQuantizationType.IQ4_XS,
+    gguf.GGMLQuantizationType.IQ4_NL,
 }
 # TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
 # Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
@@ -73,8 +72,7 @@ MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
 MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
 
 
-def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor,
-                        qweight_type: int) -> torch.Tensor:
+def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor:
     # there is no need to call any kernel for fp16/bf16
     if qweight_type in UNQUANTIZED_TYPES:
         return x @ qweight.T
@@ -87,8 +85,8 @@ def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor,
     #     y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
     # elif qweight_type in MMQ_QUANT_TYPES:
     #     y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
-    # If there is no available MMQ kernel, fallback to dequantize
 
+    # If there is no available MMQ kernel, fallback to dequantize
     elif qweight_type in DEQUANT_TYPES:
         block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
         shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
@@ -98,9 +96,8 @@ def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor,
         # Raise an error if the quantization type is not supported.
         # Might be useful if llama.cpp adds a new quantization type.
        # Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
-        qweight_type = WeightType(qweight_type)
-        raise NotImplementedError(
-            f"Unsupported GGUF quantization type: {qweight_type}")
+        qweight_type = gguf.GGMLQuantizationType(qweight_type)
+        raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")
     return y
 
 
@@ -528,12 +525,12 @@ class GGUFLinear(nn.Linear):
         self.compute_dtype = compute_dtype
         self.device = device
 
-    def forward(self, inputs):
+    def forward(self, inputs: torch.Tensor):
         if ops is not None and self.weight.is_cuda and inputs.is_cuda:
             return self.forward_cuda(inputs)
         return self.forward_native(inputs)
 
-    def forward_native(self, inputs):
+    def forward_native(self, inputs: torch.Tensor):
         weight = dequantize_gguf_tensor(self.weight)
         weight = weight.to(self.compute_dtype)
         bias = self.bias.to(self.compute_dtype) if self.bias is not None else None
@@ -541,12 +538,9 @@ class GGUFLinear(nn.Linear):
         output = torch.nn.functional.linear(inputs, weight, bias)
         return output
 
-    def forward_cuda(self, inputs):
+    def forward_cuda(self, inputs: torch.Tensor):
         quant_type = self.weight.quant_type
-        orig_shape = inputs.shape
-        inputs = inputs.view(-1, orig_shape[-1])
         output = _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type)
         if self.bias is not None:
-            output = output + self.bias.to(self.compute_dtype)
-        return output.view(*orig_shape[:-1], -1)
-
+            output += self.bias.to(self.compute_dtype)
+        return output
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index 72f020ec19..72b12badf2 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -76,9 +76,9 @@ from .import_utils import (
     is_hpu_available,
     is_inflect_available,
     is_invisible_watermark_available,
-    is_kernels_available,
     is_k_diffusion_available,
     is_k_diffusion_version,
+    is_kernels_available,
     is_librosa_available,
     is_matplotlib_available,
     is_nltk_available,
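
Illustrative sketch (not part of the patch) of the shape arithmetic behind the dequantize fallback in _fused_mul_mat_gguf: gguf.GGML_QUANT_SIZES maps each quantization type to (block_size, type_size), i.e. how many weights are packed into how many bytes, and the fallback uses that ratio to recover the logical weight shape from the packed byte tensor before calling the dequantize kernel. The quantization type and layer dimensions below are hypothetical.

import gguf

# Hypothetical quant type and layer dimensions, chosen only for illustration.
qtype = gguf.GGMLQuantizationType.Q4_K
block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]  # weights per block, bytes per block

out_features, in_features = 3072, 4096  # in_features must be divisible by block_size

# The stored byte tensor packs each row of in_features weights into
# (in_features // block_size) blocks of type_size bytes:
packed_cols = in_features // block_size * type_size

# The fallback inverts that packing, matching the expression in the patch:
# (qweight.shape[0], qweight.shape[1] // type_size * block_size)
logical_shape = (out_features, packed_cols // type_size * block_size)
assert logical_shape == (out_features, in_features)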