From df58c8017e77833be570b3377906b3de2e3ac1f7 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 21 Aug 2025 17:08:29 +0530
Subject: [PATCH] up

---
 .../quantizers/nunchaku/nunchaku_quantizer.py | 64 +++++++++------
 src/diffusers/quantizers/nunchaku/utils.py    | 81 -------------------
 2 files changed, 40 insertions(+), 105 deletions(-)
 delete mode 100644 src/diffusers/quantizers/nunchaku/utils.py

diff --git a/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py b/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py
index 7cee0bccc3..dbb68a68e8 100644
--- a/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py
+++ b/src/diffusers/quantizers/nunchaku/nunchaku_quantizer.py
@@ -20,11 +20,6 @@ if TYPE_CHECKING:
 if is_torch_available():
     import torch
 
-if is_accelerate_available():
-    pass
-
-if is_nunchaku_available():
-    from .utils import replace_with_nunchaku_linear
 
 
 logger = logging.get_logger(__name__)
@@ -79,13 +74,14 @@ class NunchakuQuantizer(DiffusersQuantizer):
         state_dict: Dict[str, Any],
         **kwargs,
     ):
-        from nunchaku.models.linear import SVDQW4A4Linear
-
-        module, tensor_name = get_module_from_name(model, param_name)
-        if self.pre_quantized and isinstance(module, SVDQW4A4Linear):
-            return True
-
-        return False
+        # TODO: revisit
+        # Check if the param_name is not in self.modules_to_not_convert
+        if any((key + "." in param_name) or (key == param_name) for key in self.modules_to_not_convert):
+            return False
+        else:
+            # We only quantize the weight of nn.Linear
+            module, _ = get_module_from_name(model, param_name)
+            return isinstance(module, torch.nn.Linear)
 
     def create_quantized_param(
         self,
@@ -112,13 +108,32 @@ class NunchakuQuantizer(DiffusersQuantizer):
                 module._buffers[tensor_name] = torch.nn.Parameter(param_value.to(target_device))
 
         elif isinstance(module, torch.nn.Linear):
-            if tensor_name in module._parameters:
-                module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
-            if tensor_name in module._buffers:
-                module._buffers[tensor_name] = torch.nn.Parameter(param_value).to(target_device)
+            # TODO: this returns an `SVDQW4A4Linear` layer initialized from the corresponding `linear` module.
+            # But we need to have a utility that can take a pretrained param value and quantize it. Not sure
+            # how to do that yet.
+            # Essentially, we need something like `bnb.nn.Params4bit.from_prequantized`. Or is there a better
+            # way to do it?
+            is_param = tensor_name in module._parameters
+            is_buffer = tensor_name in module._buffers
+            new_module = SVDQW4A4Linear.from_linear(
+                module, precision=self.quantization_config.precision, rank=self.quantization_config.rank
+            )
+            module_name = ".".join(param_name.split(".")[:-1])
+            if "." in module_name:
+                parent_name, leaf = module_name.rsplit(".", 1)
+                parent = model.get_submodule(parent_name)
+            else:
+                parent, leaf = model, module_name
 
-            new_module = SVDQW4A4Linear.from_linear(module)
-            setattr(model, param_name, new_module)
+            # rebind
+            # this will result into
+            # AttributeError: 'SVDQW4A4Linear' object has no attribute 'weight'. Did you mean: 'qweight'.
+            if is_param:
+                new_module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
+            elif is_buffer:
+                new_module._buffers[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
+
+            setattr(parent, leaf, new_module)
 
     def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
         max_memory = {key: val * 0.90 for key, val in max_memory.items()}
@@ -157,24 +172,25 @@ class NunchakuQuantizer(DiffusersQuantizer):
         keep_in_fp32_modules: List[str] = [],
         **kwargs,
     ):
-        # TODO: deal with `device_map`
         self.modules_to_not_convert = self.quantization_config.modules_to_not_convert
 
         if not isinstance(self.modules_to_not_convert, list):
            self.modules_to_not_convert = [self.modules_to_not_convert]
 
         self.modules_to_not_convert.extend(keep_in_fp32_modules)
+
+        # TODO: revisit
+        # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
+        # if isinstance(device_map, dict) and len(device_map.keys()) > 1:
+        #     keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
+        #     self.modules_to_not_convert.extend(keys_on_cpu)
+
         # Purge `None`.
         # Unlike `transformers`, we don't know if we should always keep certain modules in FP32
         # in case of diffusion transformer models. For language models and others alike, `lm_head`
         # and tied modules are usually kept in FP32.
         self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]
 
-        model = replace_with_nunchaku_linear(
-            model,
-            modules_to_not_convert=self.modules_to_not_convert,
-            quantization_config=self.quantization_config,
-        )
         model.config.quantization_config = self.quantization_config
 
     def _process_model_after_weight_loading(self, model, **kwargs):
diff --git a/src/diffusers/quantizers/nunchaku/utils.py b/src/diffusers/quantizers/nunchaku/utils.py
deleted file mode 100644
index b8b015e6ee..0000000000
--- a/src/diffusers/quantizers/nunchaku/utils.py
+++ /dev/null
@@ -1,81 +0,0 @@
-import torch.nn as nn
-
-from ...utils import is_accelerate_available, is_nunchaku_available, logging
-
-
-if is_accelerate_available():
-    from accelerate import init_empty_weights
-
-
-logger = logging.get_logger(__name__)
-
-
-def _replace_with_nunchaku_linear(
-    model,
-    svdq_linear_cls,
-    modules_to_not_convert=None,
-    current_key_name=None,
-    quantization_config=None,
-    has_been_replaced=False,
-):
-    for name, module in model.named_children():
-        if current_key_name is None:
-            current_key_name = []
-        current_key_name.append(name)
-
-        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
-            # Check if the current key is not in the `modules_to_not_convert`
-            current_key_name_str = ".".join(current_key_name)
-            if not any(
-                (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
-            ):
-                with init_empty_weights():
-                    in_features = module.in_features
-                    out_features = module.out_features
-
-                    if quantization_config.precision in ["int4", "nvfp4"]:
-                        model._modules[name] = svdq_linear_cls(
-                            in_features,
-                            out_features,
-                            rank=quantization_config.rank,
-                            bias=module.bias is not None,
-                            torch_dtype=module.weight.dtype,
-                        )
-                        has_been_replaced = True
-                    # Store the module class in case we need to transpose the weight later
-                    model._modules[name].source_cls = type(module)
-                    # Force requires grad to False to avoid unexpected errors
-                    model._modules[name].requires_grad_(False)
-        if len(list(module.children())) > 0:
-            _, has_been_replaced = _replace_with_nunchaku_linear(
-                module,
-                svdq_linear_cls,
-                modules_to_not_convert,
-                current_key_name,
-                quantization_config,
-                has_been_replaced=has_been_replaced,
-            )
-        # Remove the last key for recursion
-        current_key_name.pop(-1)
-    return model, has_been_replaced
-
-
-def replace_with_nunchaku_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
-    if is_nunchaku_available():
-        from nunchaku.models.linear import SVDQW4A4Linear
-
-        model, _ = _replace_with_nunchaku_linear(
-            model, SVDQW4A4Linear, modules_to_not_convert, current_key_name, quantization_config
-        )
-
-        has_been_replaced = any(
-            isinstance(replaced_module, SVDQW4A4Linear) for _, replaced_module in model.named_modules()
-        )
-        if not has_been_replaced:
-            logger.warning(
-                "You are loading your model in the SVDQuant method but no linear modules were found in your model."
-                " Please double check your model architecture, or submit an issue on github if you think this is"
-                " a bug."
-            )
-
-    return model
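
Note on the rebind TODO in `create_quantized_param`: the pattern reduces to three generic steps: locate the parent of the module that owns `param_name`, register the incoming prequantized tensor on the replacement module under the same tensor name, and re-assign the child on the parent. The following is a minimal sketch of that pattern using only `torch` APIs, not the patch's final implementation; `quantize_linear_fn` is a hypothetical stand-in for a factory such as `SVDQW4A4Linear.from_linear(module, precision=..., rank=...)`, and the example key is made up. The `isinstance(child, nn.Linear)` guard converts each layer only once even when several tensors (e.g. `qweight`, scales) target the same module, which may be related to the `AttributeError` noted in the TODO.

# Sketch only (not part of the patch). `quantize_linear_fn` is an assumed factory,
# e.g. a thin wrapper around SVDQW4A4Linear.from_linear.
from typing import Callable

import torch
import torch.nn as nn


def load_prequantized_tensor(
    model: nn.Module,
    param_name: str,  # assumes a dotted key, e.g. "blocks.0.attn.to_q.qweight" (hypothetical)
    param_value: torch.Tensor,
    quantize_linear_fn: Callable[[nn.Linear], nn.Module],
    target_device: torch.device = torch.device("cpu"),
) -> None:
    module_name, tensor_name = param_name.rsplit(".", 1)
    module = model.get_submodule(module_name)

    # Convert only once: if an earlier tensor for this layer already swapped the
    # module, it is no longer an nn.Linear and must not pass through the factory
    # again (it has no `.weight` anymore).
    if isinstance(module, nn.Linear):
        module = quantize_linear_fn(module)

    # Re-register the checkpoint tensor under the name the quantized module expects.
    # Integer-typed quantized weights cannot require grad, so register them as
    # non-trainable parameters (or buffers, if the module declares them that way).
    value = param_value.to(target_device)
    if tensor_name in dict(module.named_buffers(recurse=False)):
        module._buffers[tensor_name] = value
    else:
        module._parameters[tensor_name] = nn.Parameter(value, requires_grad=False)

    # Rebind the (possibly new) module on its parent so the model graph points at it.
    if "." in module_name:
        parent_name, leaf = module_name.rsplit(".", 1)
        parent = model.get_submodule(parent_name)
    else:
        parent, leaf = model, module_name
    setattr(parent, leaf, module)

Converting lazily per tensor like this is what lets the patch drop the full-model sweep that the deleted `replace_with_nunchaku_linear` pass used to perform in `_process_model_before_weight_loading`.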