
OpenVINO Quantization support with NNCF

Author: Disty0
Date: 2024-01-25 20:22:57 +03:00
parent d1bb51eee0
commit db6fd95351

4 changed files with 78 additions and 44 deletions

View File

@@ -209,6 +209,9 @@ As of this release, default backend is set to **diffusers** as its more feature
   - disable 1024x1024 workaround if the GPU supports 64 bit
   - fix lock-ups at very high resolutions
 - **OpenVINO**, thanks @disty0
+  - **quantization support with NNCF**
+    run 8 bit directly on your GPU without autocast
+    enable *OpenVINO Quantize Models with NNCF* from *Compute Settings*
   - **4-bit support with NNCF**
     enable *Compress Model weights with NNCF* from *Compute Settings* and set a 4-bit NNCF mode
     select both CPU and GPU from the device selection if you want to use 4-bit or 8-bit modes on GPU
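For reference, a minimal sketch of what these two settings map to in NNCF terms; the model path and device below are placeholders, not sdnext code:

    import nncf
    from openvino.runtime import Core

    core = Core()
    om = core.read_model("model.xml")  # placeholder: any OpenVINO IR model

    # "Compress Model weights with NNCF": weight-only 8-bit/4-bit compression
    om = nncf.compress_weights(om, mode=nncf.CompressWeightsMode.INT4_SYM, group_size=8, ratio=1.0)

    compiled = core.compile_model(om, "GPU")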

View File

@@ -463,7 +463,7 @@ def check_torch():
         torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cpu')
         install(os.environ.get('OPENVINO_PACKAGE', 'openvino==2023.3.0'), 'openvino')
         install('onnxruntime-openvino', 'onnxruntime-openvino', ignore=True) # TODO openvino: numpy version conflicts with tensorflow and doesn't support Python 3.11
-        install('nncf==2.7.0', 'nncf')
+        install('nncf==2.8.0', 'nncf')
         os.environ.setdefault('PYTORCH_TRACING_MODE', 'TORCHFX')
         os.environ.setdefault('NEOReadDebugKeys', '1')
         os.environ.setdefault('ClDeviceGlobalMemSizeAvailablePercent', '100')
@@ -533,7 +533,7 @@ def check_torch():
             log.debug(f'Cannot install xformers package: {e}')
     if opts.get('cuda_compile_backend', '') == 'hidet':
         install('hidet', 'hidet')
-    if opts.get('nncf_compress_weights', False):
+    if opts.get('nncf_compress_weights', False) and not args.use_openvino:
         install('nncf==2.7.0', 'nncf')
     if args.profile:
         print_profile(pr, 'Torch')

View File

@@ -20,8 +20,7 @@ import functools
 from modules import shared, devices, sd_models
-NNCFNodeName = str
-def get_node_by_name(self, name: NNCFNodeName) -> nncf.common.graph.NNCFNode:
+def get_node_by_name(self, name: str) -> nncf.common.graph.NNCFNode:
     node_ids = self._node_name_to_node_id_map.get(name, None)
     if node_ids is None:
         raise RuntimeError("Could not find a node {} in NNCFGraph!".format(name))
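Note that get_node_by_name is defined at module level yet takes self: it is meant to replace NNCF's own NNCFGraph method. A hedged sketch of that monkey-patch pattern (the import path inside nncf is an assumption):

    from nncf.common.graph.graph import NNCFGraph  # assumed location of NNCFGraph

    # bind the patched lookup onto the class, overriding the stock implementation
    NNCFGraph.get_node_by_name = get_node_by_name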
@@ -179,13 +178,14 @@ def execute_cached(compiled_model, *args):
     result = [torch.from_numpy(res[out]) for out in compiled_model.outputs]
     return result
-def openvino_compile(gm: GraphModule, *args, model_hash_str: str = None, file_name=""):
+def openvino_compile(gm: GraphModule, *example_inputs, model_hash_str: str = None, file_name=""):
     core = Core()
     device = get_device()
     cache_root = shared.opts.openvino_cache_path
     global dont_use_4bit_nncf
     global dont_use_nncf
+    global dont_use_quant
     if file_name is not None and os.path.isfile(file_name + ".xml") and os.path.isfile(file_name + ".bin"):
         om = core.read_model(file_name + ".xml")
@@ -195,7 +195,7 @@ def openvino_compile(gm: GraphModule, *args, model_hash_str: str = None, file_na
     input_shapes = []
     input_types = []
-    for input_data in args:
+    for input_data in example_inputs:
         if isinstance(input_data, torch.SymInt):
             input_types.append(torch.SymInt)
             input_shapes.append(1)
@@ -213,7 +213,7 @@ def openvino_compile(gm: GraphModule, *args, model_hash_str: str = None, file_na
             serialize(om, file_name + ".xml", file_name + ".bin")
             if (shared.compiled_model_state.cn_model != []):
                 f = open(file_name + ".txt", "w")
-                for input_data in args:
+                for input_data in example_inputs:
                     f.write(str(input_data.size()))
                     f.write("\n")
                 f.close()
@@ -229,42 +229,6 @@ def openvino_compile(gm: GraphModule, *args, model_hash_str: str = None, file_na
         torch.bool: Type.boolean
     }
-    for idx, input_data in enumerate(args):
-        om.inputs[idx].get_node().set_element_type(dtype_mapping[input_data.dtype])
-        om.inputs[idx].get_node().set_partial_shape(PartialShape(list(input_data.shape)))
-    om.validate_nodes_and_infer_types()
-    if shared.opts.nncf_compress_weights and not dont_use_nncf:
-        if dont_use_4bit_nncf or shared.opts.nncf_compress_weights_mode == "INT8":
-            om = nncf.compress_weights(om)
-        else:
-            om = nncf.compress_weights(om, mode=getattr(nncf.CompressWeightsMode, shared.opts.nncf_compress_weights_mode), group_size=8, ratio=shared.opts.nncf_compress_weights_raito)
-    if model_hash_str is not None:
-        core.set_property({'CACHE_DIR': cache_root + '/blob'})
-    dont_use_nncf = False
-    dont_use_4bit_nncf = False
-    compiled_model = core.compile_model(om, device)
-    return compiled_model
-
-def openvino_compile_cached_model(cached_model_path, *example_inputs):
-    core = Core()
-    om = core.read_model(cached_model_path + ".xml")
-    global dont_use_4bit_nncf
-    global dont_use_nncf
-    dtype_mapping = {
-        torch.float32: Type.f32,
-        torch.float64: Type.f64,
-        torch.float16: Type.f16,
-        torch.int64: Type.i64,
-        torch.int32: Type.i32,
-        torch.uint8: Type.u8,
-        torch.int8: Type.i8,
-        torch.bool: Type.boolean
-    }
     for idx, input_data in enumerate(example_inputs):
         om.inputs[idx].get_node().set_element_type(dtype_mapping[input_data.dtype])
         om.inputs[idx].get_node().set_partial_shape(PartialShape(list(input_data.shape)))
@@ -274,9 +238,70 @@ def openvino_compile_cached_model(cached_model_path, *example_inputs):
             om = nncf.compress_weights(om)
         else:
             om = nncf.compress_weights(om, mode=getattr(nncf.CompressWeightsMode, shared.opts.nncf_compress_weights_mode), group_size=8, ratio=shared.opts.nncf_compress_weights_raito)
+    if shared.opts.nncf_quantize and not dont_use_quant:
+        new_inputs = []
+        for idx, _ in enumerate(example_inputs):
+            new_inputs.append(example_inputs[idx].detach().cpu().numpy())
+        new_inputs = [new_inputs]
+        if shared.opts.nncf_quant_mode == "INT8":
+            nncf.quantize(om, nncf.Dataset(new_inputs))
+        else:
+            nncf.quantize(om, nncf.Dataset(new_inputs), mode=getattr(nncf.QuantizationMode, shared.opts.nncf_quant_mode),
+                advanced_parameters=nncf.quantization.advanced_parameters.AdvancedQuantizationParameters(
+                overflow_fix=nncf.quantization.advanced_parameters.OverflowFix.DISABLE, backend_params=None))
+    if model_hash_str is not None:
+        core.set_property({'CACHE_DIR': cache_root + '/blob'})
+    dont_use_nncf = False
+    dont_use_quant = False
+    dont_use_4bit_nncf = False
+    compiled_model = core.compile_model(om, device)
+    return compiled_model
+
+def openvino_compile_cached_model(cached_model_path, *example_inputs):
+    core = Core()
+    om = core.read_model(cached_model_path + ".xml")
+    global dont_use_4bit_nncf
+    global dont_use_nncf
+    global dont_use_quant
+    dtype_mapping = {
+        torch.float32: Type.f32,
+        torch.float64: Type.f64,
+        torch.float16: Type.f16,
+        torch.int64: Type.i64,
+        torch.int32: Type.i32,
+        torch.uint8: Type.u8,
+        torch.int8: Type.i8,
+        torch.bool: Type.boolean
+    }
+    for idx, input_data in enumerate(example_inputs):
+        om.inputs[idx].get_node().set_element_type(dtype_mapping[input_data.dtype])
+        om.inputs[idx].get_node().set_partial_shape(PartialShape(list(input_data.shape)))
+    om.validate_nodes_and_infer_types()
+    if shared.opts.nncf_compress_weights and not dont_use_nncf:
+        if dont_use_4bit_nncf or shared.opts.nncf_compress_weights_mode == "INT8":
+            om = nncf.compress_weights(om)
+        else:
+            om = nncf.compress_weights(om, mode=getattr(nncf.CompressWeightsMode, shared.opts.nncf_compress_weights_mode), group_size=8, ratio=shared.opts.nncf_compress_weights_raito)
+    if shared.opts.nncf_quantize and not dont_use_quant:
+        new_inputs = []
+        for idx, _ in enumerate(example_inputs):
+            new_inputs.append(example_inputs[idx].detach().cpu().numpy())
+        new_inputs = [new_inputs]
+        if shared.opts.nncf_quant_mode == "INT8":
+            nncf.quantize(om, nncf.Dataset(new_inputs))
+        else:
+            nncf.quantize(om, nncf.Dataset(new_inputs), mode=getattr(nncf.QuantizationMode, shared.opts.nncf_quant_mode),
+                advanced_parameters=nncf.quantization.advanced_parameters.AdvancedQuantizationParameters(
+                overflow_fix=nncf.quantization.advanced_parameters.OverflowFix.DISABLE, backend_params=None))
     core.set_property({'CACHE_DIR': shared.opts.openvino_cache_path + '/blob'})
     dont_use_nncf = False
+    dont_use_quant = False
     dont_use_4bit_nncf = False
     compiled_model = core.compile_model(om, get_device())
@@ -366,10 +391,12 @@ def get_subgraph_type(tensor):
 def openvino_fx(subgraph, example_inputs):
     global dont_use_4bit_nncf
     global dont_use_nncf
+    global dont_use_quant
     global subgraph_type
     dont_use_4bit_nncf = False
     dont_use_nncf = False
+    dont_use_quant = False
     dont_use_faketensors = False
     executor_parameters = None
     inputs_reversed = False
@@ -386,6 +413,7 @@ def openvino_fx(subgraph, example_inputs):
         dont_use_4bit_nncf = True
         dont_use_nncf = bool("VAE" not in shared.opts.nncf_compress_weights)
+        dont_use_quant = bool("VAE" not in shared.opts.nncf_quantize)
     # SD 1.5 / SDXL Text Encoder
     elif (subgraph_type[0] is torch.nn.modules.sparse.Embedding and
@@ -395,6 +423,7 @@ def openvino_fx(subgraph, example_inputs):
         dont_use_faketensors = True
         dont_use_nncf = bool("Text Encoder" not in shared.opts.nncf_compress_weights)
+        dont_use_quant = bool("Text Encoder" not in shared.opts.nncf_quantize)
     if not shared.opts.openvino_disable_model_caching:
         os.environ.setdefault('OPENVINO_TORCH_MODEL_CACHING', "1")
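The quantization path added above feeds the traced example inputs back into nncf.quantize() as a one-sample calibration dataset. A standalone sketch of the same call pattern, assuming om is an OpenVINO model and example_inputs is a sequence of torch tensors; note that nncf.quantize returns the quantized model rather than modifying om in place:

    import nncf

    # one calibration sample: the tuple of example inputs as numpy arrays
    calibration_data = [[t.detach().cpu().numpy() for t in example_inputs]]
    om = nncf.quantize(om, nncf.Dataset(calibration_data))  # INT8 by default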

View File

@@ -378,7 +378,9 @@ options_templates.update(options_section(('cuda', "Compute Settings"), {
"openvino_sep": OptionInfo("<h2>OpenVINO</h2>", "", gr.HTML, {"visible": cmd_opts.use_openvino}),
"openvino_devices": OptionInfo([], "OpenVINO devices to use", gr.CheckboxGroup, {"choices": get_openvino_device_list() if cmd_opts.use_openvino else [], "visible": cmd_opts.use_openvino}),
"nncf_compress_weights_mode": OptionInfo("INT8", "OpenVINO compress mode for NNCF", gr.Radio, {"choices": ['INT8', 'INT4_SYM', 'INT4_ASYM', 'NF4'], "visible": cmd_opts.use_openvino}),
"nncf_quantize": OptionInfo([], "OpenVINO Quantize Models with NNCF", gr.CheckboxGroup, {"choices": ["Model", "VAE", "Text Encoder"], "visible": cmd_opts.use_openvino}),
"nncf_quant_mode": OptionInfo("INT8", "OpenVINO Quantization Mode with NNCF", gr.Radio, {"choices": ['INT8', 'FP8_E4M3', 'FP8_E5M2'], "visible": cmd_opts.use_openvino}),
"nncf_compress_weights_mode": OptionInfo("INT8", "OpenVINO compress mode for NNCF", gr.Radio, {"choices": ['INT8', 'INT8_SYM', 'INT4_ASYM', 'INT4_SYM', 'NF4'], "visible": cmd_opts.use_openvino}),
"nncf_compress_weights_raito": OptionInfo(1.0, "OpenVINO compress ratio for NNCF", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01, "visible": cmd_opts.use_openvino}),
"openvino_disable_model_caching": OptionInfo(False, "OpenVINO disable model caching", gr.Checkbox, {"visible": cmd_opts.use_openvino}),
}))
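These options are read per traced subgraph: openvino_fx() classifies each subgraph (main model, VAE, Text Encoder) and sets the dont_use_* flags accordingly. A simplified sketch of that gating, with the component label as a stand-in for the real detection logic:

    # hypothetical helper mirroring the dont_use_quant / dont_use_nncf flags
    def component_enabled(component: str, selected: list) -> bool:
        # component is one of "Model", "VAE", "Text Encoder"
        return component in selected

    dont_use_quant = not component_enabled("VAE", shared.opts.nncf_quantize)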