Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-29 07:22:12 +03:00
@@ -17,10 +17,11 @@ import inspect
from contextlib import nullcontext

import gguf
from gguf import GGMLQuantizationType as WeightType
import torch
import torch.nn as nn

from ...utils import is_accelerate_available
from ...utils import is_accelerate_available, is_kernels_available


if is_accelerate_available():
@@ -29,6 +30,76 @@ if is_accelerate_available():
    from accelerate.hooks import add_hook_to_module, remove_hook_from_module


can_use_cuda_kernels = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 7
if can_use_cuda_kernels and is_kernels_available():
    from kernels import get_kernel
    ops = get_kernel("Isotr0py/ggml")
else:
    ops = None

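# Illustrative sketch (not part of the patch): the fused path is only taken on a
# CUDA device with compute capability >= 7.0 (Volta or newer) when the optional
# `kernels` package is installed; `get_kernel("Isotr0py/ggml")` then fetches the
# prebuilt GGML CUDA ops from the Hugging Face Hub. A downstream script can probe
# the same gate before relying on the fused kernels. The helper name below is
# hypothetical.
import torch

def gguf_cuda_kernels_enabled() -> bool:
    # Mirror the gate above: `kernels` importable + CUDA + compute capability >= 7.
    try:
        from kernels import get_kernel  # noqa: F401
    except ImportError:
        return False
    return torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 7

print("fused GGUF kernels available:", gguf_cuda_kernels_enabled())
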
UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16}
STANDARD_QUANT_TYPES = {
    WeightType.Q4_0,
    WeightType.Q4_1,
    WeightType.Q5_0,
    WeightType.Q5_1,
    WeightType.Q8_0,
    WeightType.Q8_1,
}
KQUANT_TYPES = {
    WeightType.Q2_K,
    WeightType.Q3_K,
    WeightType.Q4_K,
    WeightType.Q5_K,
    WeightType.Q6_K,
}
IMATRIX_QUANT_TYPES = {
    WeightType.IQ1_M,
    WeightType.IQ1_S,
    WeightType.IQ2_XXS,
    WeightType.IQ2_XS,
    WeightType.IQ2_S,
    WeightType.IQ3_XXS,
    WeightType.IQ3_S,
    WeightType.IQ4_XS,
    WeightType.IQ4_NL,
}
# TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
# MMQ kernel for I-Matrix quantization.
DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES

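# Illustrative sketch (not part of the patch): `gguf.GGML_QUANT_SIZES` maps each
# quantization type to (block_size, type_size), i.e. how many weights one block
# holds and how many bytes that block occupies. The dequantize fallback below
# relies on this to recover the logical column count of a packed row:
# n_packed_bytes // type_size * block_size. The 4096-wide layer is a made-up
# size for illustration.
import gguf
from gguf import GGMLQuantizationType as WeightType

for qtype in (WeightType.Q4_0, WeightType.Q8_0, WeightType.Q4_K):
    block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
    in_features = 4096
    packed_bytes_per_row = in_features // block_size * type_size
    recovered = packed_bytes_per_row // type_size * block_size
    print(qtype.name, block_size, type_size, packed_bytes_per_row, recovered)
    assert recovered == in_features
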
def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor:
    # there is no need to call any kernel for fp16/bf16
    if qweight_type in UNQUANTIZED_TYPES:
        return x @ qweight.T
    # enable MMVQ in contiguous batching with batch_size=1
    if qweight_type in MMVQ_QUANT_TYPES:
        y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
    # Use MMQ Kernel if it's available (standard + k-quants)
    elif qweight_type in MMQ_QUANT_TYPES:
        y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
    # If there is no available MMQ kernel, fallback to dequantize
    elif qweight_type in DEQUANT_TYPES:
        block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
        shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
        weight = ops.ggml_dequantize(qweight, qweight_type, *shape, x.dtype)
        y = x @ weight.T
    else:
        # Raise an error if the quantization type is not supported.
        # Might be useful if llama.cpp adds a new quantization type.
        # Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
        qweight_type = WeightType(qweight_type)
        raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")
    return y

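# Illustrative sketch (not part of the patch): expected tensor layout for
# _fused_mul_mat_gguf. `x` holds activations of shape (tokens, in_features) and
# `qweight` has one row per output feature. For F32/F16/BF16 the call reduces to
# a plain matmul and never touches `ops`, so it runs even without the CUDA
# kernels; the sizes below are made up. For quantized types, `qweight` must be
# the raw GGUF-packed byte buffer and the Isotr0py/ggml kernels do the work.
x = torch.randn(2, 64)    # (tokens, in_features), float32
w = torch.randn(128, 64)  # (out_features, in_features), float32
y = _fused_mul_mat_gguf(x, w, WeightType.F32)
assert y.shape == (2, 128)
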
# Copied from diffusers.quantizers.bitsandbytes.utils._create_accelerate_new_hook
def _create_accelerate_new_hook(old_hook):
    r"""
@@ -451,11 +522,22 @@ class GGUFLinear(nn.Linear):
    ) -> None:
        super().__init__(in_features, out_features, bias, device)
        self.compute_dtype = compute_dtype
        self.device = device

    def forward(self, inputs):
        if ops is not None and self.weight.is_cuda and inputs.is_cuda:
            return self.forward_cuda(inputs)
        return self.forward_native(inputs)

    def forward_native(self, inputs):
        weight = dequantize_gguf_tensor(self.weight)
        weight = weight.to(self.compute_dtype)
        bias = self.bias.to(self.compute_dtype) if self.bias is not None else None

        output = torch.nn.functional.linear(inputs, weight, bias)
        return output

    def forward_cuda(self, inputs):
        quant_type = self.weight.quant_type
        return _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type)

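# Illustrative sketch (not part of the patch): with `kernels` installed and a CUDA
# GPU of compute capability >= 7.0, GGUFLinear.forward routes to forward_cuda and
# the fused GGML kernels above; otherwise it dequantizes the weight and falls back
# to torch.nn.functional.linear. The checkpoint path below is a placeholder; any
# GGUF single-file checkpoint supported by diffusers should behave the same way.
import torch
from diffusers import FluxTransformer2DModel, GGUFQuantizationConfig

transformer = FluxTransformer2DModel.from_single_file(
    "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q4_K_S.gguf",  # placeholder
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
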
@@ -76,6 +76,7 @@ from .import_utils import (
    is_hpu_available,
    is_inflect_available,
    is_invisible_watermark_available,
    is_kernels_available,
    is_k_diffusion_available,
    is_k_diffusion_version,
    is_librosa_available,

@@ -192,6 +192,7 @@ _torch_xla_available, _torch_xla_version = _is_package_available("torch_xla")
_torch_npu_available, _torch_npu_version = _is_package_available("torch_npu")
_transformers_available, _transformers_version = _is_package_available("transformers")
_hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub")
_kernels_available, _kernels_version = _is_package_available("kernels")
_inflect_available, _inflect_version = _is_package_available("inflect")
_unidecode_available, _unidecode_version = _is_package_available("unidecode")
_k_diffusion_available, _k_diffusion_version = _is_package_available("k_diffusion")
@@ -274,6 +275,10 @@ def is_accelerate_available():
    return _accelerate_available


def is_kernels_available():
    return _kernels_available


def is_k_diffusion_available():
    return _k_diffusion_available