diff --git a/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py b/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py
index 113d9a4ba1..c61ac26928 100644
--- a/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py
+++ b/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py
@@ -164,6 +164,11 @@ class NunchakuQuantizer(DiffusersQuantizer):
             self.modules_to_not_convert = [self.modules_to_not_convert]
         self.modules_to_not_convert.extend(keep_in_fp32_modules)
 
+        # Purge `None`.
+        # Unlike `transformers`, we don't know if we should always keep certain modules in FP32
+        # in case of diffusion transformer models. For language models and others alike, `lm_head`
+        # and tied modules are usually kept in FP32.
+        self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]
 
         model = replace_with_nunchaku_linear(
             model,
diff --git a/src/diffusers/quantizers/nunchaku/utils.py b/src/diffusers/quantizers/nunchaku/utils.py
index af36d7e638..b8b015e6ee 100644
--- a/src/diffusers/quantizers/nunchaku/utils.py
+++ b/src/diffusers/quantizers/nunchaku/utils.py
@@ -5,7 +5,6 @@ from ...utils import is_accelerate_available, is_nunchaku_available, logging
 
 if is_accelerate_available():
     from accelerate import init_empty_weights
 
-
 logger = logging.get_logger(__name__)
 
@@ -40,7 +39,7 @@ def _replace_with_nunchaku_linear(
                        out_features,
                        rank=quantization_config.rank,
                        bias=module.bias is not None,
-                        dtype=model.dtype,
+                        torch_dtype=module.weight.dtype,
                    )
                    has_been_replaced = True
                    # Store the module class in case we need to transpose the weight later
@@ -50,6 +49,7 @@
        if len(list(module.children())) > 0:
            _, has_been_replaced = _replace_with_nunchaku_linear(
                module,
+                svdq_linear_cls,
                modules_to_not_convert,
                current_key_name,
                quantization_config,
@@ -64,7 +64,9 @@ def replace_with_nunchaku_linear(model, modules_to_not_convert=None, current_key
    if is_nunchaku_available():
        from nunchaku.models.linear import SVDQW4A4Linear
 
-        model, _ = _replace_with_nunchaku_linear(model, SVDQW4A4Linear, modules_to_not_convert, current_key_name, quantization_config)
+        model, _ = _replace_with_nunchaku_linear(
+            model, SVDQW4A4Linear, modules_to_not_convert, current_key_name, quantization_config
+        )
 
        has_been_replaced = any(
            isinstance(replaced_module, SVDQW4A4Linear) for _, replaced_module in model.named_modules()
diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py
index 057a5aec71..acde80879a 100644
--- a/src/diffusers/quantizers/quantization_config.py
+++ b/src/diffusers/quantizers/quantization_config.py
@@ -750,6 +750,7 @@ class NunchakuConfig(QuantizationConfigMixin):
    ):
        self.quant_method = QuantizationMethod.NUNCHAKU
        self.precision = precision
+        self.rank = rank
        self.group_size = self.group_size_map[precision]
        self.modules_to_not_convert = modules_to_not_convert
 
@@ -763,6 +764,8 @@
        if self.precision not in accpeted_precision:
            raise ValueError(f"Only supported precision in {accpeted_precision} but found {self.precision}")
 
+        # TODO: should there be a check for rank?
+
    # Copied from diffusers.quantizers.bitsandbytes.quantization_config.BitsandBytesConfig.to_diff_dict with BitsandBytesConfig->NunchakuConfig
    def to_diff_dict(self) -> Dict[str, Any]:
        """
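
For reference, a minimal usage sketch of the config path wired up above, assuming the usual diffusers `from_pretrained(quantization_config=...)` entry point. The model class, checkpoint path, `precision="int4"`, and `rank=32` below are illustrative placeholders rather than values confirmed by this patch; check `NunchakuConfig` for the actually accepted precision strings and defaults.

# Sketch only: names marked "assumed"/"hypothetical" are not taken from this diff.
import torch
from diffusers import FluxTransformer2DModel
from diffusers.quantizers.quantization_config import NunchakuConfig

quant_config = NunchakuConfig(
    precision="int4",             # assumed precision string
    rank=32,                      # assumed SVD rank for SVDQW4A4Linear
    modules_to_not_convert=None,  # `None` entries get purged by the quantizer (first hunk above)
)

transformer = FluxTransformer2DModel.from_pretrained(
    "path/to/checkpoint",         # hypothetical checkpoint id
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)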