up
@@ -164,6 +164,11 @@ class NunchakuQuantizer(DiffusersQuantizer):
             self.modules_to_not_convert = [self.modules_to_not_convert]
 
         self.modules_to_not_convert.extend(keep_in_fp32_modules)
+        # Purge `None`.
+        # Unlike `transformers`, we don't know if we should always keep certain modules in FP32
+        # in case of diffusion transformer models. For language models and others alike, `lm_head`
+        # and tied modules are usually kept in FP32.
+        self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]
 
         model = replace_with_nunchaku_linear(
             model,
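For reference, a minimal standalone sketch of the normalization this hunk performs; the module names and the `keep_in_fp32_modules` contents are illustrative, not taken from any real model:

# Standalone sketch of the normalization above; names are illustrative.
modules_to_not_convert = "proj_out"           # a caller may pass a single name
keep_in_fp32_modules = ["norm_out", None]     # upstream lists can contain None

if isinstance(modules_to_not_convert, str):
    modules_to_not_convert = [modules_to_not_convert]
modules_to_not_convert.extend(keep_in_fp32_modules)
# Purge `None` so later name-membership checks stay clean.
modules_to_not_convert = [m for m in modules_to_not_convert if m is not None]
print(modules_to_not_convert)  # ['proj_out', 'norm_out']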
@@ -5,7 +5,6 @@ from ...utils import is_accelerate_available, is_nunchaku_available, logging
 
 if is_accelerate_available():
     from accelerate import init_empty_weights
 
-
 
 logger = logging.get_logger(__name__)
@@ -40,7 +39,7 @@ def _replace_with_nunchaku_linear(
                         out_features,
                         rank=quantization_config.rank,
                         bias=module.bias is not None,
-                        dtype=model.dtype,
+                        torch_dtype=module.weight.dtype,
                     )
                     has_been_replaced = True
                     # Store the module class in case we need to transpose the weight later
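The dtype switch in this hunk matters because a partially quantized model's submodules need not share a single dtype. A small sketch in plain torch, with illustrative modules:

import torch
import torch.nn as nn

# Submodules of one model can carry different dtypes, e.g. when some are kept in FP32.
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8)).to(torch.bfloat16)
model[1].to(torch.float32)  # a module deliberately kept in FP32

print(model[0].weight.dtype)  # torch.bfloat16
print(model[1].weight.dtype)  # torch.float32
# A single model-level dtype cannot describe both, so the replacement reads
# `module.weight.dtype` to match each individual layer it swaps out.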
@@ -50,6 +49,7 @@ def _replace_with_nunchaku_linear(
         if len(list(module.children())) > 0:
             _, has_been_replaced = _replace_with_nunchaku_linear(
                 module,
+                svdq_linear_cls,
                 modules_to_not_convert,
                 current_key_name,
                 quantization_config,
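The added `svdq_linear_cls` argument keeps the recursive call's positional arguments aligned with the function signature. A simplified sketch of the recurse-and-replace pattern (not the actual diffusers helper, which also tracks key names and the quantization config):

import torch.nn as nn

def _replace_linear(model: nn.Module, new_cls, skip=()):
    # Simplified sketch of recursive module replacement.
    replaced = False
    for name, module in model.named_children():
        if isinstance(module, nn.Linear) and name not in skip:
            setattr(model, name, new_cls(module.in_features, module.out_features))
            replaced = True
        elif len(list(module.children())) > 0:
            # Pass every argument through, including the replacement class;
            # omitting one shifts the positional arguments out of place,
            # which is the kind of slip this hunk corrects.
            _, child_replaced = _replace_linear(module, new_cls, skip)
            replaced = replaced or child_replaced
    return model, replaced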
@@ -64,7 +64,9 @@ def replace_with_nunchaku_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
     if is_nunchaku_available():
         from nunchaku.models.linear import SVDQW4A4Linear
 
-        model, _ = _replace_with_nunchaku_linear(model, SVDQW4A4Linear, modules_to_not_convert, current_key_name, quantization_config)
+        model, _ = _replace_with_nunchaku_linear(
+            model, SVDQW4A4Linear, modules_to_not_convert, current_key_name, quantization_config
+        )
 
     has_been_replaced = any(
         isinstance(replaced_module, SVDQW4A4Linear) for _, replaced_module in model.named_modules()
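The `has_been_replaced` scan above derives the flag from the finished model rather than threading it out of the recursion. The same pattern in isolation, with a stand-in class since `SVDQW4A4Linear` requires nunchaku:

import torch.nn as nn

class FakeQuantLinear(nn.Linear):
    # Stand-in for SVDQW4A4Linear so the sketch runs without nunchaku.
    pass

model = nn.Sequential(nn.Linear(4, 4), FakeQuantLinear(4, 4))
has_been_replaced = any(isinstance(m, FakeQuantLinear) for _, m in model.named_modules())
print(has_been_replaced)  # True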
@@ -750,6 +750,7 @@ class NunchakuConfig(QuantizationConfigMixin):
     ):
         self.quant_method = QuantizationMethod.NUNCHAKU
         self.precision = precision
+        self.rank = rank
         self.group_size = self.group_size_map[precision]
         self.modules_to_not_convert = modules_to_not_convert
 
@@ -763,6 +764,8 @@ class NunchakuConfig(QuantizationConfigMixin):
         if self.precision not in accepted_precision:
             raise ValueError(f"Only precisions in {accepted_precision} are supported, but found {self.precision}")
 
+        # TODO: should there be a check for rank?
+
     # Copied from diffusers.quantizers.bitsandbytes.quantization_config.BitsandBytesConfig.to_diff_dict with BitsandBytesConfig->NunchakuConfig
     def to_diff_dict(self) -> Dict[str, Any]:
         """
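Taken together, the config side of the change would be exercised roughly as follows; a hedged sketch, assuming `NunchakuConfig` ends up exported like the other diffusers quantization configs and that "int4" is among the accepted precisions (the precision value, rank, and module name here are all illustrative):

from diffusers import NunchakuConfig  # assumed export path

config = NunchakuConfig(
    precision="int4",                     # validated against the accepted set
    rank=32,                              # stored as-is; see the TODO about a rank check
    modules_to_not_convert=["proj_out"],  # illustrative module name
)
print(config.group_size)  # looked up from `group_size_map[precision]`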