1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
sayakpaul
2025-08-21 17:08:29 +05:30
parent 2a827ec19f
commit df58c8017e
2 changed files with 40 additions and 105 deletions

View File

@@ -20,11 +20,6 @@ if TYPE_CHECKING:
if is_torch_available():
import torch
if is_accelerate_available():
pass
if is_nunchaku_available():
from .utils import replace_with_nunchaku_linear
logger = logging.get_logger(__name__)
@@ -79,13 +74,14 @@ class NunchakuQuantizer(DiffusersQuantizer):
state_dict: Dict[str, Any],
**kwargs,
):
from nunchaku.models.linear import SVDQW4A4Linear
module, tensor_name = get_module_from_name(model, param_name)
if self.pre_quantized and isinstance(module, SVDQW4A4Linear):
return True
return False
# TODO: revisit
# Check if the param_name is not in self.modules_to_not_convert
if any((key + "." in param_name) or (key == param_name) for key in self.modules_to_not_convert):
return False
else:
# We only quantize the weight of nn.Linear
module, _ = get_module_from_name(model, param_name)
return isinstance(module, torch.nn.Linear)
def create_quantized_param(
self,
@@ -112,13 +108,32 @@ class NunchakuQuantizer(DiffusersQuantizer):
module._buffers[tensor_name] = torch.nn.Parameter(param_value.to(target_device))
elif isinstance(module, torch.nn.Linear):
if tensor_name in module._parameters:
module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
if tensor_name in module._buffers:
module._buffers[tensor_name] = torch.nn.Parameter(param_value).to(target_device)
# TODO: this returns an `SVDQW4A4Linear` layer initialized from the corresponding `linear` module.
# But we need to have a utility that can take a pretrained param value and quantize it. Not sure
# how to do that yet.
# Essentially, we need something like `bnb.nn.Params4bit.from_prequantized`. Or is there a better
# way to do it?
is_param = tensor_name in module._parameters
is_buffer = tensor_name in module._buffers
new_module = SVDQW4A4Linear.from_linear(
module, precision=self.quantization_config.precision, rank=self.quantization_config.rank
)
module_name = ".".join(param_name.split(".")[:-1])
if "." in module_name:
parent_name, leaf = module_name.rsplit(".", 1)
parent = model.get_submodule(parent_name)
else:
parent, leaf = model, module_name
new_module = SVDQW4A4Linear.from_linear(module)
setattr(model, param_name, new_module)
# rebind
# this will result into
# AttributeError: 'SVDQW4A4Linear' object has no attribute 'weight'. Did you mean: 'qweight'.
if is_param:
new_module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
elif is_buffer:
new_module._buffers[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
setattr(parent, leaf, new_module)
def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
max_memory = {key: val * 0.90 for key, val in max_memory.items()}
@@ -157,24 +172,25 @@ class NunchakuQuantizer(DiffusersQuantizer):
keep_in_fp32_modules: List[str] = [],
**kwargs,
):
# TODO: deal with `device_map`
self.modules_to_not_convert = self.quantization_config.modules_to_not_convert
if not isinstance(self.modules_to_not_convert, list):
self.modules_to_not_convert = [self.modules_to_not_convert]
self.modules_to_not_convert.extend(keep_in_fp32_modules)
# TODO: revisit
# Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
# if isinstance(device_map, dict) and len(device_map.keys()) > 1:
# keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
# self.modules_to_not_convert.extend(keys_on_cpu)
# Purge `None`.
# Unlike `transformers`, we don't know if we should always keep certain modules in FP32
# in case of diffusion transformer models. For language models and others alike, `lm_head`
# and tied modules are usually kept in FP32.
self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]
model = replace_with_nunchaku_linear(
model,
modules_to_not_convert=self.modules_to_not_convert,
quantization_config=self.quantization_config,
)
model.config.quantization_config = self.quantization_config
def _process_model_after_weight_loading(self, model, **kwargs):

View File

@@ -1,81 +0,0 @@
import torch.nn as nn
from ...utils import is_accelerate_available, is_nunchaku_available, logging
if is_accelerate_available():
from accelerate import init_empty_weights
logger = logging.get_logger(__name__)
def _replace_with_nunchaku_linear(
model,
svdq_linear_cls,
modules_to_not_convert=None,
current_key_name=None,
quantization_config=None,
has_been_replaced=False,
):
for name, module in model.named_children():
if current_key_name is None:
current_key_name = []
current_key_name.append(name)
if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
# Check if the current key is not in the `modules_to_not_convert`
current_key_name_str = ".".join(current_key_name)
if not any(
(key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
):
with init_empty_weights():
in_features = module.in_features
out_features = module.out_features
if quantization_config.precision in ["int4", "nvfp4"]:
model._modules[name] = svdq_linear_cls(
in_features,
out_features,
rank=quantization_config.rank,
bias=module.bias is not None,
torch_dtype=module.weight.dtype,
)
has_been_replaced = True
# Store the module class in case we need to transpose the weight later
model._modules[name].source_cls = type(module)
# Force requires grad to False to avoid unexpected errors
model._modules[name].requires_grad_(False)
if len(list(module.children())) > 0:
_, has_been_replaced = _replace_with_nunchaku_linear(
module,
svdq_linear_cls,
modules_to_not_convert,
current_key_name,
quantization_config,
has_been_replaced=has_been_replaced,
)
# Remove the last key for recursion
current_key_name.pop(-1)
return model, has_been_replaced
def replace_with_nunchaku_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
if is_nunchaku_available():
from nunchaku.models.linear import SVDQW4A4Linear
model, _ = _replace_with_nunchaku_linear(
model, SVDQW4A4Linear, modules_to_not_convert, current_key_name, quantization_config
)
has_been_replaced = any(
isinstance(replaced_module, SVDQW4A4Linear) for _, replaced_module in model.named_modules()
)
if not has_been_replaced:
logger.warning(
"You are loading your model in the SVDQuant method but no linear modules were found in your model."
" Please double check your model architecture, or submit an issue on github if you think this is"
" a bug."
)
return model