mirror of
https://github.com/vladmandic/sdnext.git
synced 2026-01-27 15:02:48 +03:00
131 lines
4.8 KiB
Python
131 lines
4.8 KiB
Python
import sys
|
|
import diffusers
|
|
from installer import install, log
|
|
|
|
|
|
# Module-level handles for the optional quantization backends. Each one stays
# None until the matching load_bnb / load_quanto / load_torchao call imports
# the package successfully, and is reset to None if that import fails.
bnb = None
|
|
quanto = None
|
|
ao = None
|
|
|
|
|
|
def create_bnb_config(kwargs = None, allow_bnb: bool = True):
    """Build a diffusers BitsAndBytes quantization config from the user options.

    Args:
        kwargs: optional dict of loader keyword arguments. When given, the
            config is stored under its 'quantization_config' key.
        allow_bnb: pass False to skip bitsandbytes quantization entirely.

    Returns:
        `kwargs` unchanged when quantization is not selected, 'Model' is not
        among the selected targets, or the bitsandbytes module handle is
        still None after loading. Otherwise the new config object when
        `kwargs` is None, or `kwargs` with the config added.
    """
    from modules import shared, devices

    targets = shared.opts.bnb_quantization
    if len(targets) == 0 or not allow_bnb:
        return kwargs
    if 'Model' not in targets:
        return kwargs

    load_bnb()
    if bnb is None:
        return kwargs

    quant_type = shared.opts.bnb_quantization_type
    quant_storage = shared.opts.bnb_quantization_storage
    config = diffusers.BitsAndBytesConfig(
        load_in_8bit=quant_type == 'fp8',
        load_in_4bit=quant_type in ('nf4', 'fp4'),
        bnb_4bit_quant_storage=quant_storage,
        bnb_4bit_quant_type=quant_type,
        bnb_4bit_compute_dtype=devices.dtype,
    )
    shared.log.debug(f'Quantization: module=all type=bnb dtype={quant_type} storage={quant_storage}')

    if kwargs is None:
        return config
    kwargs['quantization_config'] = config
    return kwargs
|
|
|
|
|
|
def create_ao_config(kwargs = None, allow_ao: bool = True):
    """Build a diffusers TorchAO quantization config from the user options.

    Only applies when the torchao quantization mode option is 'pre'.

    Args:
        kwargs: optional dict of loader keyword arguments. When given, the
            config is stored under its 'quantization_config' key.
        allow_ao: pass False to skip torchao quantization entirely.

    Returns:
        `kwargs` unchanged when quantization is not selected, the mode is not
        'pre', 'Model' is not among the selected targets, or the torchao
        module handle is still None after loading. Otherwise the new config
        object when `kwargs` is None, or `kwargs` with the config added.
    """
    from modules import shared

    targets = shared.opts.torchao_quantization
    if len(targets) == 0 or shared.opts.torchao_quantization_mode != 'pre' or not allow_ao:
        return kwargs
    if 'Model' not in targets:
        return kwargs

    load_torchao()
    if ao is None:
        return kwargs

    # Override diffusers' own availability probe so it reports torchao as present.
    diffusers.utils.import_utils.is_torchao_available = lambda: True
    config = diffusers.TorchAoConfig(shared.opts.torchao_quantization_type)
    shared.log.debug(f'Quantization: module=all type=torchao dtype={shared.opts.torchao_quantization_type}')

    if kwargs is None:
        return config
    kwargs['quantization_config'] = config
    return kwargs
|
|
|
|
|
|
def load_torchao(msg='', silent=False):
    """Install and import torchao on first use, caching the module in `ao`.

    Args:
        msg: prefix for the error log entry written when the import fails;
            nothing is logged when it is empty.
        silent: when True, an import failure returns None instead of
            re-raising the exception.

    Returns:
        The torchao module, or None if the import failed and `silent` is set.

    Raises:
        Exception: whatever the import raised, when `silent` is False.
    """
    global ao # pylint: disable=global-statement
    if ao is not None:
        return ao
    install('torchao==0.7.0', quiet=True)
    try:
        import torchao
        ao = torchao
        # Walk f_back rather than calling sys._getframe(2): the latter raises
        # ValueError on a shallow call stack, which this handler would then
        # misreport as a failed import and discard the loaded module.
        caller = sys._getframe().f_back # pylint: disable=protected-access
        parent = caller.f_back if caller is not None else None
        caller_name = caller.f_code.co_name if caller is not None else 'unknown'
        parent_name = parent.f_code.co_name if parent is not None else 'unknown'
        fn = f'{parent_name}:{caller_name}'
        log.debug(f'Quantization: type=torchao version={ao.__version__} fn={fn}')
        return ao
    except Exception as e:
        if len(msg) > 0:
            log.error(f"{msg} failed to import torchao: {e}")
        ao = None
        if not silent:
            raise
    return None
|
|
|
|
|
|
def load_bnb(msg='', silent=False):
    """Install and import bitsandbytes on first use, caching the module in `bnb`.

    On success the diffusers import helpers are also patched so that
    diffusers treats bitsandbytes as available.

    Args:
        msg: prefix for the error log entry written when the import fails;
            nothing is logged when it is empty.
        silent: when True, an import failure returns None instead of
            re-raising the exception.

    Returns:
        The bitsandbytes module, or None if the import failed and `silent`
        is set.

    Raises:
        Exception: whatever the import raised, when `silent` is False.
    """
    global bnb # pylint: disable=global-statement
    if bnb is not None:
        return bnb
    install('bitsandbytes==0.45.0', quiet=True)
    try:
        import bitsandbytes
        bnb = bitsandbytes
        diffusers.utils.import_utils._bitsandbytes_available = True # pylint: disable=protected-access
        # NOTE(review): the version reported to diffusers is pinned to '0.43.3'
        # although 0.45.0 is installed above - presumably to satisfy a
        # diffusers version check; confirm before changing.
        diffusers.utils.import_utils._bitsandbytes_version = '0.43.3' # pylint: disable=protected-access
        # Walk f_back rather than calling sys._getframe(2): the latter raises
        # ValueError on a shallow call stack, which this handler would then
        # misreport as a failed import and discard the loaded module.
        caller = sys._getframe().f_back # pylint: disable=protected-access
        parent = caller.f_back if caller is not None else None
        caller_name = caller.f_code.co_name if caller is not None else 'unknown'
        parent_name = parent.f_code.co_name if parent is not None else 'unknown'
        fn = f'{parent_name}:{caller_name}'
        log.debug(f'Quantization: type=bitsandbytes version={bnb.__version__} fn={fn}')
        return bnb
    except Exception as e:
        if len(msg) > 0:
            log.error(f"{msg} failed to import bitsandbytes: {e}")
        bnb = None
        if not silent:
            raise
    return None
|
|
|
|
|
|
def load_quanto(msg='', silent=False):
    """Install and import optimum.quanto on first use, caching it in `quanto`.

    After a successful import an error is logged (but loading still
    succeeds) if the diffusers offload mode option is 'balanced' or
    'sequential'.

    Args:
        msg: prefix for the error log entry written when the import fails;
            nothing is logged when it is empty.
        silent: when True, an import failure returns None instead of
            re-raising the exception.

    Returns:
        The optimum.quanto module, or None if the import failed and `silent`
        is set.

    Raises:
        Exception: whatever the import raised, when `silent` is False.
    """
    from modules import shared
    global quanto # pylint: disable=global-statement
    if quanto is not None:
        return quanto
    install('optimum-quanto==0.2.6', quiet=True)
    try:
        from optimum import quanto as optimum_quanto # pylint: disable=no-name-in-module
        quanto = optimum_quanto
        # Walk f_back rather than calling sys._getframe(2): the latter raises
        # ValueError on a shallow call stack, which this handler would then
        # misreport as a failed import and discard the loaded module.
        caller = sys._getframe().f_back # pylint: disable=protected-access
        parent = caller.f_back if caller is not None else None
        caller_name = caller.f_code.co_name if caller is not None else 'unknown'
        parent_name = parent.f_code.co_name if parent is not None else 'unknown'
        fn = f'{parent_name}:{caller_name}'
        log.debug(f'Quantization: type=quanto version={quanto.__version__} fn={fn}')
        if shared.opts.diffusers_offload_mode in {'balanced', 'sequential'}:
            shared.log.error(f'Quantization: type=quanto offload={shared.opts.diffusers_offload_mode} not supported')
        return quanto
    except Exception as e:
        if len(msg) > 0:
            log.error(f"{msg} failed to import optimum.quanto: {e}")
        quanto = None
        if not silent:
            raise
    return None
|
|
|
|
|
|
def get_quant(name):
    """Guess the quantization type of a model from its name.

    Matching is case-insensitive. The markers are tested in a fixed priority
    order, so a name containing several of them resolves to the first one
    listed below; the '.gguf' suffix is only considered when no marker is
    present.

    Args:
        name: model name or filename.

    Returns:
        One of 'qint8', 'qint4', 'fp8', 'fp4', 'nf4', 'gguf', or 'none' when
        nothing matches.
    """
    lowered = name.lower()
    for marker in ('qint8', 'qint4', 'fp8', 'fp4', 'nf4'):
        if marker in lowered:
            return marker
    # Check the suffix on the lowered name so '.GGUF' is recognized as well,
    # consistent with the case-insensitive marker checks above.
    if lowered.endswith('.gguf'):
        return 'gguf'
    return 'none'
|