@@ -20,11 +20,6 @@ if TYPE_CHECKING:
if is_torch_available():
    import torch

if is_accelerate_available():
    pass

if is_nunchaku_available():
    from .utils import replace_with_nunchaku_linear

logger = logging.get_logger(__name__)
@@ -79,13 +74,14 @@ class NunchakuQuantizer(DiffusersQuantizer):
        state_dict: Dict[str, Any],
        **kwargs,
    ):
        from nunchaku.models.linear import SVDQW4A4Linear

        module, tensor_name = get_module_from_name(model, param_name)
        if self.pre_quantized and isinstance(module, SVDQW4A4Linear):
            return True

        return False
        # TODO: revisit
        # Check if the param_name is not in self.modules_to_not_convert
        if any((key + "." in param_name) or (key == param_name) for key in self.modules_to_not_convert):
            return False
        else:
            # We only quantize the weight of nn.Linear
            module, _ = get_module_from_name(model, param_name)
            return isinstance(module, torch.nn.Linear)

    def create_quantized_param(
        self,
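For context on the `check_quantized_param` logic above: `get_module_from_name` maps a flat checkpoint key such as "transformer_blocks.0.ff.proj.weight" to the module that owns the tensor plus the trailing tensor name. A minimal stand-in (illustrative only, not the actual diffusers helper; the function name here is made up) behaves like this:

import torch


def get_module_and_tensor_name(model: torch.nn.Module, param_name: str):
    # Split "block.proj.weight" into the owning module ("block.proj") and "weight".
    if "." in param_name:
        module_path, tensor_name = param_name.rsplit(".", 1)
        module = model.get_submodule(module_path)
    else:
        module, tensor_name = model, param_name
    return module, tensor_name


model = torch.nn.Sequential(torch.nn.Linear(4, 4))
module, tensor_name = get_module_and_tensor_name(model, "0.weight")
assert isinstance(module, torch.nn.Linear) and tensor_name == "weight"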
@@ -112,13 +108,32 @@ class NunchakuQuantizer(DiffusersQuantizer):
            module._buffers[tensor_name] = torch.nn.Parameter(param_value.to(target_device))

        elif isinstance(module, torch.nn.Linear):
            if tensor_name in module._parameters:
                module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
            if tensor_name in module._buffers:
                module._buffers[tensor_name] = torch.nn.Parameter(param_value).to(target_device)
            # TODO: this returns an `SVDQW4A4Linear` layer initialized from the corresponding `linear` module.
            # But we need to have a utility that can take a pretrained param value and quantize it. Not sure
            # how to do that yet.
            # Essentially, we need something like `bnb.nn.Params4bit.from_prequantized`. Or is there a better
            # way to do it?
            is_param = tensor_name in module._parameters
            is_buffer = tensor_name in module._buffers
            new_module = SVDQW4A4Linear.from_linear(
                module, precision=self.quantization_config.precision, rank=self.quantization_config.rank
            )
            module_name = ".".join(param_name.split(".")[:-1])
            if "." in module_name:
                parent_name, leaf = module_name.rsplit(".", 1)
                parent = model.get_submodule(parent_name)
            else:
                parent, leaf = model, module_name

            new_module = SVDQW4A4Linear.from_linear(module)
            setattr(model, param_name, new_module)
            # rebind
            # this will result in
            # AttributeError: 'SVDQW4A4Linear' object has no attribute 'weight'. Did you mean: 'qweight'.
            if is_param:
                new_module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
            elif is_buffer:
                new_module._buffers[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)

            setattr(parent, leaf, new_module)

    def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
        max_memory = {key: val * 0.90 for key, val in max_memory.items()}
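The parent/leaf resolution in this hunk is the standard way to swap a freshly built quantized layer into the model: resolve the parent of the target module, then rebind the child with setattr. A runnable sketch of just that pattern (a plain torch.nn.Linear stands in for SVDQW4A4Linear.from_linear so nunchaku is not required; the toy classes are made up):

import torch


def swap_module(model: torch.nn.Module, param_name: str, new_module: torch.nn.Module):
    # "block.proj.weight" -> module path "block.proj" -> parent "block", leaf "proj"
    module_path = param_name.rsplit(".", 1)[0]
    if "." in module_path:
        parent_name, leaf = module_path.rsplit(".", 1)
        parent = model.get_submodule(parent_name)
    else:
        parent, leaf = model, module_path
    setattr(parent, leaf, new_module)


class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)


class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.block = Block()


model = Toy()
swap_module(model, "block.proj.weight", torch.nn.Linear(8, 8, bias=False))
assert model.block.proj.bias is None  # the child module was rebound on its parent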
@@ -157,24 +172,25 @@ class NunchakuQuantizer(DiffusersQuantizer):
        keep_in_fp32_modules: List[str] = [],
        **kwargs,
    ):
        # TODO: deal with `device_map`
        self.modules_to_not_convert = self.quantization_config.modules_to_not_convert

        if not isinstance(self.modules_to_not_convert, list):
            self.modules_to_not_convert = [self.modules_to_not_convert]

        self.modules_to_not_convert.extend(keep_in_fp32_modules)

        # TODO: revisit
        # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
        # if isinstance(device_map, dict) and len(device_map.keys()) > 1:
        #     keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
        #     self.modules_to_not_convert.extend(keys_on_cpu)

        # Purge `None`.
        # Unlike `transformers`, we don't know if we should always keep certain modules in FP32
        # in case of diffusion transformer models. For language models and others alike, `lm_head`
        # and tied modules are usually kept in FP32.
        self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]

        model = replace_with_nunchaku_linear(
            model,
            modules_to_not_convert=self.modules_to_not_convert,
            quantization_config=self.quantization_config,
        )
        model.config.quantization_config = self.quantization_config

    def _process_model_after_weight_loading(self, model, **kwargs):
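The `modules_to_not_convert` list assembled above is consumed with a prefix-match convention (key + "." in param_name, or key == param_name), both in `check_quantized_param` and in the replacement utility below. A small illustration of how that matching behaves (module names are made up):

def should_skip(param_name: str, modules_to_not_convert: list) -> bool:
    return any((key + "." in param_name) or (key == param_name) for key in modules_to_not_convert)


skip = ["proj_out", "norm_out"]
assert should_skip("proj_out.weight", skip)          # prefix match -> kept unquantized
assert should_skip("norm_out", skip)                 # exact match -> kept unquantized
assert not should_skip("proj_out_alt.weight", skip)  # different module -> still converted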
@@ -1,81 +0,0 @@
import torch.nn as nn

from ...utils import is_accelerate_available, is_nunchaku_available, logging


if is_accelerate_available():
    from accelerate import init_empty_weights


logger = logging.get_logger(__name__)


def _replace_with_nunchaku_linear(
    model,
    svdq_linear_cls,
    modules_to_not_convert=None,
    current_key_name=None,
    quantization_config=None,
    has_been_replaced=False,
):
    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)

        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
            # Check if the current key is not in the `modules_to_not_convert`
            current_key_name_str = ".".join(current_key_name)
            if not any(
                (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
            ):
                with init_empty_weights():
                    in_features = module.in_features
                    out_features = module.out_features

                    if quantization_config.precision in ["int4", "nvfp4"]:
                        model._modules[name] = svdq_linear_cls(
                            in_features,
                            out_features,
                            rank=quantization_config.rank,
                            bias=module.bias is not None,
                            torch_dtype=module.weight.dtype,
                        )
                        has_been_replaced = True
                        # Store the module class in case we need to transpose the weight later
                        model._modules[name].source_cls = type(module)
                        # Force requires grad to False to avoid unexpected errors
                        model._modules[name].requires_grad_(False)
        if len(list(module.children())) > 0:
            _, has_been_replaced = _replace_with_nunchaku_linear(
                module,
                svdq_linear_cls,
                modules_to_not_convert,
                current_key_name,
                quantization_config,
                has_been_replaced=has_been_replaced,
            )
        # Remove the last key for recursion
        current_key_name.pop(-1)
    return model, has_been_replaced


def replace_with_nunchaku_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
    if is_nunchaku_available():
        from nunchaku.models.linear import SVDQW4A4Linear

    model, _ = _replace_with_nunchaku_linear(
        model, SVDQW4A4Linear, modules_to_not_convert, current_key_name, quantization_config
    )

    has_been_replaced = any(
        isinstance(replaced_module, SVDQW4A4Linear) for _, replaced_module in model.named_modules()
    )
    if not has_been_replaced:
        logger.warning(
            "You are loading your model in the SVDQuant method but no linear modules were found in your model."
            " Please double check your model architecture, or submit an issue on github if you think this is"
            " a bug."
        )

    return model
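For reference, the deleted helper walks the model with the same named_children recursion used by the other diffusers quantization backends. A self-contained sketch of that traversal (a dummy class stands in for SVDQW4A4Linear so it runs without nunchaku; precision/rank handling is omitted):

import torch.nn as nn


class DummyQuantLinear(nn.Linear):
    """Placeholder for the real quantized linear class."""


def replace_linears(model: nn.Module, modules_to_not_convert=(), prefix=""):
    for name, module in model.named_children():
        full_name = f"{prefix}{name}"
        if isinstance(module, nn.Linear) and full_name not in modules_to_not_convert:
            model._modules[name] = DummyQuantLinear(
                module.in_features, module.out_features, bias=module.bias is not None
            )
        elif len(list(module.children())) > 0:
            replace_linears(module, modules_to_not_convert, prefix=f"{full_name}.")
    return model


toy = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.Linear(4, 2)))
replace_linears(toy)
assert isinstance(toy[0], DummyQuantLinear) and isinstance(toy[1][0], DummyQuantLinear)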