Mirror of https://github.com/huggingface/diffusers.git

commit 5d08150a2e
parent 9e0caa7afc
Author: sayakpaul
Date: 2025-08-21 14:39:58 +05:30

3 changed files with 13 additions and 3 deletions

View File

@@ -164,6 +164,11 @@ class NunchakuQuantizer(DiffusersQuantizer):
             self.modules_to_not_convert = [self.modules_to_not_convert]
         self.modules_to_not_convert.extend(keep_in_fp32_modules)

+        # Purge `None`.
+        # Unlike `transformers`, we don't know if we should always keep certain modules in FP32
+        # in case of diffusion transformer models. For language models and others alike, `lm_head`
+        # and tied modules are usually kept in FP32.
+        self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]

         model = replace_with_nunchaku_linear(
             model,
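For context, a minimal runnable sketch of the merge-and-purge behavior the new comment documents (the module names here are hypothetical, not from the diff):

    # Names are hypothetical; this mirrors the merge-and-purge steps above.
    modules_to_not_convert = "proj_out"        # a user may pass a bare string
    keep_in_fp32_modules = ["norm_out", None]  # lists can carry `None` entries

    if not isinstance(modules_to_not_convert, list):
        modules_to_not_convert = [modules_to_not_convert]
    modules_to_not_convert.extend(keep_in_fp32_modules)

    # Purge `None` so later matching against module names cannot fail.
    modules_to_not_convert = [m for m in modules_to_not_convert if m is not None]
    print(modules_to_not_convert)  # ['proj_out', 'norm_out']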

View File

@@ -5,7 +5,6 @@ from ...utils import is_accelerate_available, is_nunchaku_available, logging

 if is_accelerate_available():
     from accelerate import init_empty_weights

 logger = logging.get_logger(__name__)
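The guarded `init_empty_weights` import is there because module replacement is typically done on a meta-device model so no real weight memory is allocated; a small illustration of that accelerate context manager:

    import torch.nn as nn
    from accelerate import init_empty_weights

    # Modules created inside the context get "meta" parameters: shapes and
    # dtypes exist, but no storage is allocated.
    with init_empty_weights():
        layer = nn.Linear(4096, 4096)

    print(layer.weight.device)  # meta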
@@ -40,7 +39,7 @@ def _replace_with_nunchaku_linear(
                         out_features,
                         rank=quantization_config.rank,
                         bias=module.bias is not None,
-                        dtype=model.dtype,
+                        torch_dtype=module.weight.dtype,
                     )
                     has_been_replaced = True
                     # Store the module class in case we need to transpose the weight later
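The switch from `model.dtype` to `module.weight.dtype` matters once submodules stop sharing a single dtype: each replacement now inherits the dtype of the exact layer it replaces, whereas one model-level dtype is ambiguous for a mixed-precision model. A short sketch with plain `torch.nn` layers:

    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
    model[1].to(torch.float16)  # deliberately mixed precision

    print(model[0].weight.dtype)  # torch.float32
    print(model[1].weight.dtype)  # torch.float16
    # No single model-wide dtype describes both layers; reading
    # module.weight.dtype per replacement is the robust choice.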
@@ -50,6 +49,7 @@ def _replace_with_nunchaku_linear(
         if len(list(module.children())) > 0:
             _, has_been_replaced = _replace_with_nunchaku_linear(
                 module,
+                svdq_linear_cls,
                 modules_to_not_convert,
                 current_key_name,
                 quantization_config,
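The added `svdq_linear_cls` line fixes the recursive call, which previously dropped that argument. A stripped-down sketch of the same recursion pattern (simplified; the real helper also threads `current_key_name` and honors `modules_to_not_convert`):

    import torch.nn as nn

    def _replace_linears(model, make_replacement):
        # Walk direct children, swap nn.Linear layers, recurse into containers.
        has_been_replaced = False
        for name, module in model.named_children():
            if isinstance(module, nn.Linear):
                setattr(model, name, make_replacement(module))
                has_been_replaced = True
            elif len(list(module.children())) > 0:
                # Every argument must be forwarded here; omitting one (as the
                # diff fixes with `svdq_linear_cls`) breaks nested replacement.
                _, child_replaced = _replace_linears(module, make_replacement)
                has_been_replaced = has_been_replaced or child_replaced
        return model, has_been_replaced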
@@ -64,7 +64,9 @@ def replace_with_nunchaku_linear(model, modules_to_not_convert=None, current_key
     if is_nunchaku_available():
         from nunchaku.models.linear import SVDQW4A4Linear

-    model, _ = _replace_with_nunchaku_linear(model, SVDQW4A4Linear, modules_to_not_convert, current_key_name, quantization_config)
+    model, _ = _replace_with_nunchaku_linear(
+        model, SVDQW4A4Linear, modules_to_not_convert, current_key_name, quantization_config
+    )

     has_been_replaced = any(
         isinstance(replaced_module, SVDQW4A4Linear) for _, replaced_module in model.named_modules()
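Note that the recursive flag is discarded (`model, _ = ...`) and `has_been_replaced` is recomputed by scanning the whole module tree. A tiny demonstration of why that scan also catches nested replacements (`FakeQuantLinear` is a stand-in, not the nunchaku class):

    import torch.nn as nn

    class FakeQuantLinear(nn.Linear):
        pass  # stand-in for SVDQW4A4Linear, purely for illustration

    model = nn.Sequential(nn.Sequential(FakeQuantLinear(4, 4)), nn.ReLU())
    found = any(isinstance(m, FakeQuantLinear) for _, m in model.named_modules())
    print(found)  # True: named_modules() walks nested containers too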

View File

@@ -750,6 +750,7 @@ class NunchakuConfig(QuantizationConfigMixin):
     ):
         self.quant_method = QuantizationMethod.NUNCHAKU
         self.precision = precision
+        self.rank = rank
         self.group_size = self.group_size_map[precision]
         self.modules_to_not_convert = modules_to_not_convert
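Storing `self.rank = rank` is what lets `_replace_with_nunchaku_linear` read `quantization_config.rank` in the second file; a minimal stand-in pair showing the failure mode the added line prevents (class names are illustrative, not from the library):

    class ConfigWithoutRank:
        def __init__(self, rank=32):
            pass  # accepts `rank` but never stores it

    class ConfigWithRank:
        def __init__(self, rank=32):
            self.rank = rank  # mirrors the `self.rank = rank` line added above

    print(ConfigWithRank(rank=64).rank)          # 64
    print(hasattr(ConfigWithoutRank(), "rank"))  # False -> AttributeError downstream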
@@ -763,6 +764,8 @@ class NunchakuConfig(QuantizationConfigMixin):
         if self.precision not in accpeted_precision:
             raise ValueError(f"Only supported precision in {accpeted_precision} but found {self.precision}")

+        # TODO: should there be a check for rank?
+
     # Copied from diffusers.quantizers.bitsandbytes.quantization_config.BitsandBytesConfig.to_diff_dict with BitsandBytesConfig->NunchakuConfig
     def to_diff_dict(self) -> Dict[str, Any]:
         """