Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
[Single File] Add GGUF support (#9964)
Squashed commit message (a long run of iterative "update" commits), including:

* Update src/diffusers/quantizers/gguf/utils.py (Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>)
* Update docs/source/en/quantization/gguf.md (Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>)

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
.github/workflows/nightly_tests.yml (2 changes, vendored)
@@ -357,6 +357,8 @@ jobs:
      matrix:
        config:
          - backend: "bitsandbytes"
            test_location: "bnb"
          - backend: "gguf"
            test_location: "gguf"
    runs-on:
      group: aws-g6e-xlarge-plus
    container:
@@ -157,6 +157,8 @@
        title: Getting Started
      - local: quantization/bitsandbytes
        title: bitsandbytes
      - local: quantization/gguf
        title: gguf
      - local: quantization/torchao
        title: torchao
      title: Quantization Methods
@@ -28,6 +28,9 @@ Learn how to quantize models in the [Quantization](../quantization/overview) guide.

[[autodoc]] BitsAndBytesConfig

## GGUFQuantizationConfig

[[autodoc]] GGUFQuantizationConfig

## TorchAoConfig

[[autodoc]] TorchAoConfig
docs/source/en/quantization/gguf.md (new file, 70 lines)
@@ -0,0 +1,70 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

-->

# GGUF

The GGUF file format is typically used to store models for inference with [GGML](https://github.com/ggerganov/ggml) and supports a variety of block-wise quantization options. Diffusers supports loading checkpoints that were prequantized and saved in the GGUF format via `from_single_file` loading on model classes. Loading GGUF checkpoints via pipelines is currently not supported.

The following example loads the [FLUX.1 DEV](https://huggingface.co/black-forest-labs/FLUX.1-dev) transformer model using the GGUF Q2_K quantization variant.

Before starting, please install gguf in your environment:

```shell
pip install -U gguf
```

Since GGUF is a single-file format, use [`~FromSingleFileMixin.from_single_file`] to load the model and pass in the [`GGUFQuantizationConfig`].

When using GGUF checkpoints, the quantized weights remain in a low-memory `dtype` (typically `torch.uint8`) and are dynamically dequantized and cast to the configured `compute_dtype` during each module's forward pass through the model. The `GGUFQuantizationConfig` allows you to set the `compute_dtype`.

The functions used for dynamic dequantization are based on the great work done by [city96](https://github.com/city96/ComfyUI-GGUF), who created the PyTorch ports of the original [`numpy`](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/gguf/quants.py) implementation by [compilade](https://github.com/compilade).

```python
import torch

from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig

ckpt_path = (
    "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
)
transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()
prompt = "A cat holding a sign that says hello world"
image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
image.save("flux-gguf.png")
```

## Supported Quantization Types

- BF16
- Q4_0
- Q4_1
- Q5_0
- Q5_1
- Q8_0
- Q2_K
- Q3_K
- Q4_K
- Q5_K
- Q6_K
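If you are unsure which quantization variant a GGUF file uses, you can inspect it with the `gguf` library before loading. The snippet below is a minimal sketch, not part of the official docs; the local file path is an assumption.

```python
from collections import Counter

from gguf import GGUFReader

reader = GGUFReader("flux1-dev-Q2_K.gguf")  # hypothetical local path
# Count the GGML quantization types used across all tensors in the file.
counts = Counter(tensor.tensor_type.name for tensor in reader.tensors)
print(counts)  # e.g. Q2_K for most weights, plus F32/F16 for norms and biases
```

Any tensor whose quantization type is not in the list above raises a `ValueError` at load time.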
@@ -17,7 +17,7 @@ Quantization techniques focus on representing data with less information while a

<Tip>

Interested in adding a new quantization method to Transformers? Refer to the [Contribute new quantization method guide](https://huggingface.co/docs/transformers/main/en/quantization/contribute) to learn more about adding a new quantization method.
Interested in adding a new quantization method to Diffusers? Refer to the [Contribute new quantization method guide](https://huggingface.co/docs/transformers/main/en/quantization/contribute) to learn more about adding a new quantization method.

</Tip>

@@ -32,4 +32,9 @@ If you are new to the quantization field, we recommend you to check out these be

## When to use what?

Diffusers supports [bitsandbytes](https://huggingface.co/docs/bitsandbytes/main/en/index) and [torchao](https://github.com/pytorch/ao). Refer to this [table](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) to help you determine which quantization backend to use.
Diffusers currently supports the following quantization methods.
- [BitsandBytes]()
- [TorchAO]()
- [GGUF]()

[This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
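To make the comparison concrete, here is a hedged sketch of constructing each of the configs listed above; the argument values are illustrative only and assume the corresponding backend (`bitsandbytes`, `gguf`, or `torchao`) is installed.

```python
import torch

from diffusers import BitsAndBytesConfig, GGUFQuantizationConfig, TorchAoConfig

bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
gguf_config = GGUFQuantizationConfig(compute_dtype=torch.bfloat16)  # GGUF checkpoints are always pre-quantized
torchao_config = TorchAoConfig("int8wo")  # torchao's weight-only int8 scheme
```

Each config is passed through the `quantization_config` argument of the relevant loading method: `from_pretrained` for bitsandbytes and torchao, `from_single_file` for GGUF.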
@@ -31,7 +31,7 @@ _import_structure = {
    "loaders": ["FromOriginalModelMixin"],
    "models": [],
    "pipelines": [],
    "quantizers.quantization_config": ["BitsAndBytesConfig", "TorchAoConfig"],
    "quantizers.quantization_config": ["BitsAndBytesConfig", "GGUFQuantizationConfig", "TorchAoConfig"],
    "schedulers": [],
    "utils": [
        "OptionalDependencyNotAvailable",

@@ -569,7 +569,7 @@ else:

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    from .configuration_utils import ConfigMixin
    from .quantizers.quantization_config import BitsAndBytesConfig, TorchAoConfig
    from .quantizers.quantization_config import BitsAndBytesConfig, GGUFQuantizationConfig, TorchAoConfig

    try:
        if not is_onnx_available():
@@ -17,8 +17,10 @@ import re
|
||||
from contextlib import nullcontext
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
from ..quantizers import DiffusersAutoQuantizer
|
||||
from ..utils import deprecate, is_accelerate_available, logging
|
||||
from .single_file_utils import (
|
||||
SingleFileComponentError,
|
||||
@@ -214,6 +216,8 @@ class FromOriginalModelMixin:
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
quantization_config = kwargs.pop("quantization_config", None)
|
||||
device = kwargs.pop("device", None)
|
||||
|
||||
if isinstance(pretrained_model_link_or_path_or_dict, dict):
|
||||
checkpoint = pretrained_model_link_or_path_or_dict
|
||||
@@ -227,6 +231,12 @@ class FromOriginalModelMixin:
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
)
|
||||
if quantization_config is not None:
|
||||
hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
|
||||
hf_quantizer.validate_environment()
|
||||
|
||||
else:
|
||||
hf_quantizer = None
|
||||
|
||||
mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name]
|
||||
|
||||
@@ -309,8 +319,36 @@ class FromOriginalModelMixin:
|
||||
with ctx():
|
||||
model = cls.from_config(diffusers_model_config)
|
||||
|
||||
# Check if `_keep_in_fp32_modules` is not None
|
||||
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
|
||||
(torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
|
||||
)
|
||||
if use_keep_in_fp32_modules:
|
||||
keep_in_fp32_modules = cls._keep_in_fp32_modules
|
||||
if not isinstance(keep_in_fp32_modules, list):
|
||||
keep_in_fp32_modules = [keep_in_fp32_modules]
|
||||
|
||||
else:
|
||||
keep_in_fp32_modules = []
|
||||
|
||||
if hf_quantizer is not None:
|
||||
hf_quantizer.preprocess_model(
|
||||
model=model,
|
||||
device_map=None,
|
||||
state_dict=diffusers_format_checkpoint,
|
||||
keep_in_fp32_modules=keep_in_fp32_modules,
|
||||
)
|
||||
|
||||
if is_accelerate_available():
|
||||
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
|
||||
param_device = torch.device(device) if device else torch.device("cpu")
|
||||
unexpected_keys = load_model_dict_into_meta(
|
||||
model,
|
||||
diffusers_format_checkpoint,
|
||||
dtype=torch_dtype,
|
||||
device=param_device,
|
||||
hf_quantizer=hf_quantizer,
|
||||
keep_in_fp32_modules=keep_in_fp32_modules,
|
||||
)
|
||||
|
||||
else:
|
||||
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
|
||||
@@ -324,7 +362,11 @@ class FromOriginalModelMixin:
|
||||
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
|
||||
)
|
||||
|
||||
if torch_dtype is not None:
|
||||
if hf_quantizer is not None:
|
||||
hf_quantizer.postprocess_model(model)
|
||||
model.hf_quantizer = hf_quantizer
|
||||
|
||||
if torch_dtype is not None and hf_quantizer is None:
|
||||
model.to(torch_dtype)
|
||||
|
||||
model.eval()
|
||||
|
||||
@@ -81,8 +81,14 @@ CHECKPOINT_KEY_NAMES = {
|
||||
"open_clip_sd3": "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight",
|
||||
"stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
|
||||
"stable_cascade_stage_c": "clip_txt_mapper.weight",
|
||||
"sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
|
||||
"sd35_large": "model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight",
|
||||
"sd3": [
|
||||
"joint_blocks.0.context_block.adaLN_modulation.1.bias",
|
||||
"model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
|
||||
],
|
||||
"sd35_large": [
|
||||
"joint_blocks.37.x_block.mlp.fc1.weight",
|
||||
"model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight",
|
||||
],
|
||||
"animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe",
|
||||
"animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias",
|
||||
"animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
|
||||
@@ -542,13 +548,20 @@ def infer_diffusers_model_type(checkpoint):
|
||||
):
|
||||
model_type = "stable_cascade_stage_b"
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["sd3"]].shape[-1] == 9216:
|
||||
if checkpoint["model.diffusion_model.pos_embed"].shape[1] == 36864:
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sd3"]) and any(
|
||||
checkpoint[key].shape[-1] == 9216 if key in checkpoint else False for key in CHECKPOINT_KEY_NAMES["sd3"]
|
||||
):
|
||||
if "model.diffusion_model.pos_embed" in checkpoint:
|
||||
key = "model.diffusion_model.pos_embed"
|
||||
else:
|
||||
key = "pos_embed"
|
||||
|
||||
if checkpoint[key].shape[1] == 36864:
|
||||
model_type = "sd3"
|
||||
elif checkpoint["model.diffusion_model.pos_embed"].shape[1] == 147456:
|
||||
elif checkpoint[key].shape[1] == 147456:
|
||||
model_type = "sd35_medium"
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["sd35_large"] in checkpoint:
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sd35_large"]):
|
||||
model_type = "sd35_large"
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint:
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
from array import array
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
@@ -26,6 +27,7 @@ import torch
|
||||
from huggingface_hub.utils import EntryNotFoundError
|
||||
|
||||
from ..utils import (
|
||||
GGUF_FILE_EXTENSION,
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
SAFETENSORS_FILE_EXTENSION,
|
||||
WEIGHTS_INDEX_NAME,
|
||||
@@ -33,6 +35,8 @@ from ..utils import (
|
||||
_get_model_file,
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
is_gguf_available,
|
||||
is_torch_available,
|
||||
is_torch_version,
|
||||
logging,
|
||||
)
|
||||
@@ -139,6 +143,8 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
|
||||
file_extension = os.path.basename(checkpoint_file).split(".")[-1]
|
||||
if file_extension == SAFETENSORS_FILE_EXTENSION:
|
||||
return safetensors.torch.load_file(checkpoint_file, device="cpu")
|
||||
elif file_extension == GGUF_FILE_EXTENSION:
|
||||
return load_gguf_checkpoint(checkpoint_file)
|
||||
else:
|
||||
weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
|
||||
return torch.load(
|
||||
@@ -211,13 +217,14 @@ def load_model_dict_into_meta(
|
||||
set_module_kwargs["dtype"] = dtype
|
||||
|
||||
# bnb params are flattened.
|
||||
# gguf quants have a different shape based on the type of quantization applied
|
||||
if empty_state_dict[param_name].shape != param.shape:
|
||||
if (
|
||||
is_quantized
|
||||
and hf_quantizer.pre_quantized
|
||||
and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
|
||||
):
|
||||
hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name].shape, param.shape)
|
||||
hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name], param)
|
||||
else:
|
||||
model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
|
||||
raise ValueError(
|
||||
@@ -396,3 +403,78 @@ def _fetch_index_file_legacy(
|
||||
index_file = None
|
||||
|
||||
return index_file
|
||||
|
||||
|
||||
def _gguf_parse_value(_value, data_type):
|
||||
if not isinstance(data_type, list):
|
||||
data_type = [data_type]
|
||||
if len(data_type) == 1:
|
||||
data_type = data_type[0]
|
||||
array_data_type = None
|
||||
else:
|
||||
if data_type[0] != 9:
|
||||
raise ValueError("Received multiple types, therefore expected the first type to indicate an array.")
|
||||
data_type, array_data_type = data_type
|
||||
|
||||
if data_type in [0, 1, 2, 3, 4, 5, 10, 11]:
|
||||
_value = int(_value[0])
|
||||
elif data_type in [6, 12]:
|
||||
_value = float(_value[0])
|
||||
elif data_type in [7]:
|
||||
_value = bool(_value[0])
|
||||
elif data_type in [8]:
|
||||
_value = array("B", list(_value)).tobytes().decode()
|
||||
elif data_type in [9]:
|
||||
_value = _gguf_parse_value(_value, array_data_type)
|
||||
return _value
|
||||
|
||||
|
||||
def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
|
||||
"""
|
||||
Load a GGUF file and return a dictionary of parsed parameters containing tensors, the parsed tokenizer and config
|
||||
attributes.
|
||||
|
||||
Args:
|
||||
gguf_checkpoint_path (`str`):
|
||||
The path to the GGUF file to load.
|
||||
return_tensors (`bool`, defaults to `False`):
|
||||
Whether to read the tensors from the file and return them. Not doing so is faster and only loads the
|
||||
metadata in memory.
|
||||
"""
|
||||
|
||||
if is_gguf_available() and is_torch_available():
|
||||
import gguf
|
||||
from gguf import GGUFReader
|
||||
|
||||
from ..quantizers.gguf.utils import SUPPORTED_GGUF_QUANT_TYPES, GGUFParameter
|
||||
else:
|
||||
logger.error(
|
||||
"Loading a GGUF checkpoint in PyTorch, requires both PyTorch and GGUF>=0.10.0 to be installed. Please see "
|
||||
"https://pytorch.org/ and https://github.com/ggerganov/llama.cpp/tree/master/gguf-py for installation instructions."
|
||||
)
|
||||
raise ImportError("Please install torch and gguf>=0.10.0 to load a GGUF checkpoint in PyTorch.")
|
||||
|
||||
reader = GGUFReader(gguf_checkpoint_path)
|
||||
|
||||
parsed_parameters = {}
|
||||
for tensor in reader.tensors:
|
||||
name = tensor.name
|
||||
quant_type = tensor.tensor_type
|
||||
|
||||
# if the tensor is a torch supported dtype do not use GGUFParameter
|
||||
is_gguf_quant = quant_type not in [gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16]
|
||||
if is_gguf_quant and quant_type not in SUPPORTED_GGUF_QUANT_TYPES:
|
||||
_supported_quants_str = "\n".join([str(type) for type in SUPPORTED_GGUF_QUANT_TYPES])
|
||||
raise ValueError(
|
||||
(
|
||||
f"{name} has a quantization type: {str(quant_type)} which is unsupported."
|
||||
"\n\nCurrently the following quantization types are supported: \n\n"
|
||||
f"{_supported_quants_str}"
|
||||
"\n\nTo request support for this quantization type please open an issue here: https://github.com/huggingface/diffusers"
|
||||
)
|
||||
)
|
||||
|
||||
weights = torch.from_numpy(tensor.data.copy())
|
||||
parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights
|
||||
|
||||
return parsed_parameters
|
||||
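Below is a hedged usage sketch of `load_gguf_checkpoint` (assuming it lives in `diffusers.models.model_loading_utils`, the module these hunks appear to modify); the local file path is hypothetical.

```python
from diffusers.models.model_loading_utils import load_gguf_checkpoint

# Returns a dict mapping tensor names to GGUFParameter (quantized) or torch.Tensor (F16/F32) objects.
state_dict = load_gguf_checkpoint("flux1-dev-Q2_K.gguf")
for name, tensor in list(state_dict.items())[:3]:
    print(name, tuple(tensor.shape), getattr(tensor, "quant_type", None))
```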
|
||||
@@ -1038,14 +1038,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
dtype_present_in_args = True
|
||||
break
|
||||
|
||||
# Checks if the model has been loaded in 4-bit or 8-bit with BNB
|
||||
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
|
||||
if getattr(self, "is_quantized", False):
|
||||
if dtype_present_in_args:
|
||||
raise ValueError(
|
||||
"You cannot cast a bitsandbytes model in a new `dtype`. Make sure to load the model using `from_pretrained` using the"
|
||||
" desired `dtype` by passing the correct `torch_dtype` argument."
|
||||
"Casting a quantized model to a new `dtype` is unsupported. To set the dtype of unquantized layers, please "
|
||||
"use the `torch_dtype` argument when loading the model using `from_pretrained` or `from_single_file`"
|
||||
)
|
||||
|
||||
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
|
||||
if getattr(self, "is_loaded_in_8bit", False):
|
||||
raise ValueError(
|
||||
"`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the"
|
||||
|
||||
@@ -524,7 +524,6 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
)
|
||||
else:
|
||||
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
|
||||
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
for index_block, block in enumerate(self.single_transformer_blocks):
|
||||
|
||||
@@ -15,23 +15,33 @@
|
||||
Adapted from
|
||||
https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/quantizers/auto.py
|
||||
"""
|
||||
|
||||
import warnings
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
|
||||
from .quantization_config import BitsAndBytesConfig, QuantizationConfigMixin, QuantizationMethod, TorchAoConfig
|
||||
from .gguf import GGUFQuantizer
|
||||
from .quantization_config import (
|
||||
BitsAndBytesConfig,
|
||||
GGUFQuantizationConfig,
|
||||
QuantizationConfigMixin,
|
||||
QuantizationMethod,
|
||||
TorchAoConfig,
|
||||
)
|
||||
from .torchao import TorchAoHfQuantizer
|
||||
|
||||
|
||||
AUTO_QUANTIZER_MAPPING = {
|
||||
"bitsandbytes_4bit": BnB4BitDiffusersQuantizer,
|
||||
"bitsandbytes_8bit": BnB8BitDiffusersQuantizer,
|
||||
"gguf": GGUFQuantizer,
|
||||
"torchao": TorchAoHfQuantizer,
|
||||
}
|
||||
|
||||
AUTO_QUANTIZATION_CONFIG_MAPPING = {
|
||||
"bitsandbytes_4bit": BitsAndBytesConfig,
|
||||
"bitsandbytes_8bit": BitsAndBytesConfig,
|
||||
"gguf": GGUFQuantizationConfig,
|
||||
"torchao": TorchAoConfig,
|
||||
}
|
||||
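As a hedged illustration of how the mappings above are used: `DiffusersAutoQuantizer.from_config` resolves a config's `quant_method` to the matching quantizer class (the import path follows the `from ..quantizers import DiffusersAutoQuantizer` usage earlier in this diff).

```python
import torch

from diffusers import GGUFQuantizationConfig
from diffusers.quantizers import DiffusersAutoQuantizer

config = GGUFQuantizationConfig(compute_dtype=torch.bfloat16)
quantizer = DiffusersAutoQuantizer.from_config(config)  # resolves to GGUFQuantizer via the "gguf" key
quantizer.validate_environment()  # raises if accelerate>=0.26.0 or gguf>=0.10.0 is missing
```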
|
||||
|
||||
@@ -204,7 +204,10 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
|
||||
|
||||
module._parameters[tensor_name] = new_value
|
||||
|
||||
def check_quantized_param_shape(self, param_name, current_param_shape, loaded_param_shape):
|
||||
def check_quantized_param_shape(self, param_name, current_param, loaded_param):
|
||||
current_param_shape = current_param.shape
|
||||
loaded_param_shape = loaded_param.shape
|
||||
|
||||
n = current_param_shape.numel()
|
||||
inferred_shape = (n,) if "bias" in param_name else ((n + 1) // 2, 1)
|
||||
if loaded_param_shape != inferred_shape:
|
||||
|
||||
src/diffusers/quantizers/gguf/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
|
||||
from .gguf_quantizer import GGUFQuantizer
|
||||
src/diffusers/quantizers/gguf/gguf_quantizer.py (new file, 159 lines)
@@ -0,0 +1,159 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
from ..base import DiffusersQuantizer
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
|
||||
|
||||
from ...utils import (
|
||||
get_module_from_name,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
is_gguf_available,
|
||||
is_gguf_version,
|
||||
is_torch_available,
|
||||
logging,
|
||||
)
|
||||
|
||||
|
||||
if is_torch_available() and is_gguf_available():
|
||||
import torch
|
||||
|
||||
from .utils import (
|
||||
GGML_QUANT_SIZES,
|
||||
GGUFParameter,
|
||||
_dequantize_gguf_and_restore_linear,
|
||||
_quant_shape_from_byte_shape,
|
||||
_replace_with_gguf_linear,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class GGUFQuantizer(DiffusersQuantizer):
|
||||
use_keep_in_fp32_modules = True
|
||||
|
||||
def __init__(self, quantization_config, **kwargs):
|
||||
super().__init__(quantization_config, **kwargs)
|
||||
|
||||
self.compute_dtype = quantization_config.compute_dtype
|
||||
self.pre_quantized = quantization_config.pre_quantized
|
||||
self.modules_to_not_convert = quantization_config.modules_to_not_convert
|
||||
|
||||
if not isinstance(self.modules_to_not_convert, list):
|
||||
self.modules_to_not_convert = [self.modules_to_not_convert]
|
||||
|
||||
def validate_environment(self, *args, **kwargs):
|
||||
if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"):
|
||||
raise ImportError(
|
||||
"Loading GGUF Parameters requires `accelerate` installed in your enviroment: `pip install 'accelerate>=0.26.0'`"
|
||||
)
|
||||
if not is_gguf_available() or is_gguf_version("<", "0.10.0"):
|
||||
raise ImportError(
|
||||
"To load GGUF format files you must have `gguf` installed in your environment: `pip install gguf>=0.10.0`"
|
||||
)
|
||||
|
||||
# Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.adjust_max_memory
|
||||
def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
|
||||
# need more space for buffers that are created during quantization
|
||||
max_memory = {key: val * 0.90 for key, val in max_memory.items()}
|
||||
return max_memory
|
||||
|
||||
def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
|
||||
if target_dtype != torch.uint8:
|
||||
logger.info(f"target_dtype {target_dtype} is replaced by `torch.uint8` for GGUF quantization")
|
||||
return torch.uint8
|
||||
|
||||
def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
|
||||
if torch_dtype is None:
|
||||
torch_dtype = self.compute_dtype
|
||||
return torch_dtype
|
||||
|
||||
def check_quantized_param_shape(self, param_name, current_param, loaded_param):
|
||||
loaded_param_shape = loaded_param.shape
|
||||
current_param_shape = current_param.shape
|
||||
quant_type = loaded_param.quant_type
|
||||
|
||||
block_size, type_size = GGML_QUANT_SIZES[quant_type]
|
||||
|
||||
inferred_shape = _quant_shape_from_byte_shape(loaded_param_shape, type_size, block_size)
|
||||
if inferred_shape != current_param_shape:
|
||||
raise ValueError(
|
||||
f"{param_name} has an expected quantized shape of: {inferred_shape}, but receieved shape: {loaded_param_shape}"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
def check_if_quantized_param(
|
||||
self,
|
||||
model: "ModelMixin",
|
||||
param_value: Union["GGUFParameter", "torch.Tensor"],
|
||||
param_name: str,
|
||||
state_dict: Dict[str, Any],
|
||||
**kwargs,
|
||||
) -> bool:
|
||||
if isinstance(param_value, GGUFParameter):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def create_quantized_param(
|
||||
self,
|
||||
model: "ModelMixin",
|
||||
param_value: Union["GGUFParameter", "torch.Tensor"],
|
||||
param_name: str,
|
||||
target_device: "torch.device",
|
||||
state_dict: Optional[Dict[str, Any]] = None,
|
||||
unexpected_keys: Optional[List[str]] = None,
|
||||
):
|
||||
module, tensor_name = get_module_from_name(model, param_name)
|
||||
if tensor_name not in module._parameters and tensor_name not in module._buffers:
|
||||
raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
|
||||
|
||||
if tensor_name in module._parameters:
|
||||
module._parameters[tensor_name] = param_value.to(target_device)
|
||||
if tensor_name in module._buffers:
|
||||
module._buffers[tensor_name] = param_value.to(target_device)
|
||||
|
||||
def _process_model_before_weight_loading(
|
||||
self,
|
||||
model: "ModelMixin",
|
||||
device_map,
|
||||
keep_in_fp32_modules: List[str] = [],
|
||||
**kwargs,
|
||||
):
|
||||
state_dict = kwargs.get("state_dict", None)
|
||||
|
||||
self.modules_to_not_convert.extend(keep_in_fp32_modules)
|
||||
self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]
|
||||
|
||||
_replace_with_gguf_linear(
|
||||
model, self.compute_dtype, state_dict, modules_to_not_convert=self.modules_to_not_convert
|
||||
)
|
||||
|
||||
def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs):
|
||||
return model
|
||||
|
||||
@property
|
||||
def is_serializable(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_trainable(self) -> bool:
|
||||
return False
|
||||
|
||||
def _dequantize(self, model):
|
||||
is_model_on_cpu = model.device.type == "cpu"
|
||||
if is_model_on_cpu:
|
||||
logger.info(
|
||||
"Model was found to be on CPU (could happen as a result of `enable_model_cpu_offload()`). So, moving it to GPU. After dequantization, will move the model back to CPU again to preserve the previous device."
|
||||
)
|
||||
model.to(torch.cuda.current_device())
|
||||
|
||||
model = _dequantize_gguf_and_restore_linear(model, self.modules_to_not_convert)
|
||||
if is_model_on_cpu:
|
||||
model.to("cpu")
|
||||
return model
|
||||
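A hedged sketch of the user-facing path into `_dequantize` above (mirroring `test_dequantize_model` later in this diff): `ModelMixin.dequantize()` restores regular `nn.Linear` modules, at the cost of materializing full-precision weights in `compute_dtype`.

```python
import torch

from diffusers import FluxTransformer2DModel, GGUFQuantizationConfig

transformer = FluxTransformer2DModel.from_single_file(
    "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf",
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
)
transformer.dequantize()  # GGUFLinear layers are replaced with plain nn.Linear
```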
src/diffusers/quantizers/gguf/utils.py (new file, 456 lines)
@@ -0,0 +1,456 @@
|
||||
# Copyright 2024 The HuggingFace Team and City96. All rights reserved.
|
||||
# #
|
||||
# # Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# # you may not use this file except in compliance with the License.
|
||||
# # You may obtain a copy of the License at
|
||||
# #
|
||||
# # http://www.apache.org/licenses/LICENSE-2.0
|
||||
# #
|
||||
# # Unless required by applicable law or agreed to in writing, software
|
||||
# # distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# # See the License for the specific language governing permissions and
|
||||
# # limitations under the License.
|
||||
|
||||
|
||||
import inspect
|
||||
from contextlib import nullcontext
|
||||
|
||||
import gguf
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...utils import is_accelerate_available
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
import accelerate
|
||||
from accelerate import init_empty_weights
|
||||
from accelerate.hooks import add_hook_to_module, remove_hook_from_module
|
||||
|
||||
|
||||
# Copied from diffusers.quantizers.bitsandbytes.utils._create_accelerate_new_hook
|
||||
def _create_accelerate_new_hook(old_hook):
|
||||
r"""
|
||||
Creates a new hook based on the old hook. Use it only if you know what you are doing! This method is a copy of:
|
||||
https://github.com/huggingface/peft/blob/748f7968f3a31ec06a1c2b0328993319ad9a150a/src/peft/utils/other.py#L245 with
|
||||
some changes
|
||||
"""
|
||||
old_hook_cls = getattr(accelerate.hooks, old_hook.__class__.__name__)
|
||||
old_hook_attr = old_hook.__dict__
|
||||
filtered_old_hook_attr = {}
|
||||
old_hook_init_signature = inspect.signature(old_hook_cls.__init__)
|
||||
for k in old_hook_attr.keys():
|
||||
if k in old_hook_init_signature.parameters:
|
||||
filtered_old_hook_attr[k] = old_hook_attr[k]
|
||||
new_hook = old_hook_cls(**filtered_old_hook_attr)
|
||||
return new_hook
|
||||
|
||||
|
||||
def _replace_with_gguf_linear(model, compute_dtype, state_dict, prefix="", modules_to_not_convert=[]):
|
||||
def _should_convert_to_gguf(state_dict, prefix):
|
||||
weight_key = prefix + "weight"
|
||||
return weight_key in state_dict and isinstance(state_dict[weight_key], GGUFParameter)
|
||||
|
||||
has_children = list(model.children())
|
||||
if not has_children:
|
||||
return
|
||||
|
||||
for name, module in model.named_children():
|
||||
module_prefix = prefix + name + "."
|
||||
_replace_with_gguf_linear(module, compute_dtype, state_dict, module_prefix, modules_to_not_convert)
|
||||
|
||||
if (
|
||||
isinstance(module, nn.Linear)
|
||||
and _should_convert_to_gguf(state_dict, module_prefix)
|
||||
and name not in modules_to_not_convert
|
||||
):
|
||||
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
with ctx():
|
||||
model._modules[name] = GGUFLinear(
|
||||
module.in_features,
|
||||
module.out_features,
|
||||
module.bias is not None,
|
||||
compute_dtype=compute_dtype,
|
||||
)
|
||||
model._modules[name].source_cls = type(module)
|
||||
# Force requires_grad to False to avoid unexpected errors
|
||||
model._modules[name].requires_grad_(False)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def _dequantize_gguf_and_restore_linear(model, modules_to_not_convert=[]):
|
||||
for name, module in model.named_children():
|
||||
if isinstance(module, GGUFLinear) and name not in modules_to_not_convert:
|
||||
device = module.weight.device
|
||||
bias = getattr(module, "bias", None)
|
||||
|
||||
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
with ctx():
|
||||
new_module = nn.Linear(
|
||||
module.in_features,
|
||||
module.out_features,
|
||||
module.bias is not None,
|
||||
device=device,
|
||||
)
|
||||
new_module.weight = nn.Parameter(dequantize_gguf_tensor(module.weight))
|
||||
if bias is not None:
|
||||
new_module.bias = bias
|
||||
|
||||
# Create a new hook and attach it in case we use accelerate
|
||||
if hasattr(module, "_hf_hook"):
|
||||
old_hook = module._hf_hook
|
||||
new_hook = _create_accelerate_new_hook(old_hook)
|
||||
|
||||
remove_hook_from_module(module)
|
||||
add_hook_to_module(new_module, new_hook)
|
||||
|
||||
new_module.to(device)
|
||||
model._modules[name] = new_module
|
||||
|
||||
has_children = list(module.children())
|
||||
if has_children:
|
||||
_dequantize_gguf_and_restore_linear(module, modules_to_not_convert)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
# dequantize operations based on torch ports of GGUF dequantize_functions
|
||||
# from City96
|
||||
# more info: https://github.com/city96/ComfyUI-GGUF/blob/main/dequant.py
|
||||
|
||||
|
||||
QK_K = 256
|
||||
K_SCALE_SIZE = 12
|
||||
|
||||
|
||||
def to_uint32(x):
|
||||
x = x.view(torch.uint8).to(torch.int32)
|
||||
return (x[:, 0] | x[:, 1] << 8 | x[:, 2] << 16 | x[:, 3] << 24).unsqueeze(1)
|
||||
|
||||
|
||||
def split_block_dims(blocks, *args):
|
||||
n_max = blocks.shape[1]
|
||||
dims = list(args) + [n_max - sum(args)]
|
||||
return torch.split(blocks, dims, dim=1)
|
||||
|
||||
|
||||
def get_scale_min(scales):
|
||||
n_blocks = scales.shape[0]
|
||||
scales = scales.view(torch.uint8)
|
||||
scales = scales.reshape((n_blocks, 3, 4))
|
||||
|
||||
d, m, m_d = torch.split(scales, scales.shape[-2] // 3, dim=-2)
|
||||
|
||||
sc = torch.cat([d & 0x3F, (m_d & 0x0F) | ((d >> 2) & 0x30)], dim=-1)
|
||||
min = torch.cat([m & 0x3F, (m_d >> 4) | ((m >> 2) & 0x30)], dim=-1)
|
||||
|
||||
return (sc.reshape((n_blocks, 8)), min.reshape((n_blocks, 8)))
|
||||
|
||||
|
||||
def dequantize_blocks_Q8_0(blocks, block_size, type_size, dtype=None):
|
||||
d, x = split_block_dims(blocks, 2)
|
||||
d = d.view(torch.float16).to(dtype)
|
||||
x = x.view(torch.int8)
|
||||
return d * x
|
||||
|
||||
|
||||
def dequantize_blocks_Q5_1(blocks, block_size, type_size, dtype=None):
|
||||
n_blocks = blocks.shape[0]
|
||||
|
||||
d, m, qh, qs = split_block_dims(blocks, 2, 2, 4)
|
||||
d = d.view(torch.float16).to(dtype)
|
||||
m = m.view(torch.float16).to(dtype)
|
||||
qh = to_uint32(qh)
|
||||
|
||||
qh = qh.reshape((n_blocks, 1)) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
|
||||
ql = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
|
||||
[0, 4], device=d.device, dtype=torch.uint8
|
||||
).reshape(1, 1, 2, 1)
|
||||
qh = (qh & 1).to(torch.uint8)
|
||||
ql = (ql & 0x0F).reshape((n_blocks, -1))
|
||||
|
||||
qs = ql | (qh << 4)
|
||||
return (d * qs) + m
|
||||
|
||||
|
||||
def dequantize_blocks_Q5_0(blocks, block_size, type_size, dtype=None):
|
||||
n_blocks = blocks.shape[0]
|
||||
|
||||
d, qh, qs = split_block_dims(blocks, 2, 4)
|
||||
d = d.view(torch.float16).to(dtype)
|
||||
qh = to_uint32(qh)
|
||||
|
||||
qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
|
||||
ql = qs.reshape(n_blocks, -1, 1, block_size // 2) >> torch.tensor(
|
||||
[0, 4], device=d.device, dtype=torch.uint8
|
||||
).reshape(1, 1, 2, 1)
|
||||
|
||||
qh = (qh & 1).to(torch.uint8)
|
||||
ql = (ql & 0x0F).reshape(n_blocks, -1)
|
||||
|
||||
qs = (ql | (qh << 4)).to(torch.int8) - 16
|
||||
return d * qs
|
||||
|
||||
|
||||
def dequantize_blocks_Q4_1(blocks, block_size, type_size, dtype=None):
|
||||
n_blocks = blocks.shape[0]
|
||||
|
||||
d, m, qs = split_block_dims(blocks, 2, 2)
|
||||
d = d.view(torch.float16).to(dtype)
|
||||
m = m.view(torch.float16).to(dtype)
|
||||
|
||||
qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
|
||||
[0, 4], device=d.device, dtype=torch.uint8
|
||||
).reshape(1, 1, 2, 1)
|
||||
qs = (qs & 0x0F).reshape(n_blocks, -1)
|
||||
|
||||
return (d * qs) + m
|
||||
|
||||
|
||||
def dequantize_blocks_Q4_0(blocks, block_size, type_size, dtype=None):
|
||||
n_blocks = blocks.shape[0]
|
||||
|
||||
d, qs = split_block_dims(blocks, 2)
|
||||
d = d.view(torch.float16).to(dtype)
|
||||
|
||||
qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
|
||||
[0, 4], device=d.device, dtype=torch.uint8
|
||||
).reshape((1, 1, 2, 1))
|
||||
qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8
|
||||
return d * qs
|
||||
|
||||
|
||||
def dequantize_blocks_Q6_K(blocks, block_size, type_size, dtype=None):
|
||||
n_blocks = blocks.shape[0]
|
||||
|
||||
(
|
||||
ql,
|
||||
qh,
|
||||
scales,
|
||||
d,
|
||||
) = split_block_dims(blocks, QK_K // 2, QK_K // 4, QK_K // 16)
|
||||
|
||||
scales = scales.view(torch.int8).to(dtype)
|
||||
d = d.view(torch.float16).to(dtype)
|
||||
d = (d * scales).reshape((n_blocks, QK_K // 16, 1))
|
||||
|
||||
ql = ql.reshape((n_blocks, -1, 1, 64)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
|
||||
(1, 1, 2, 1)
|
||||
)
|
||||
ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
|
||||
qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape(
|
||||
(1, 1, 4, 1)
|
||||
)
|
||||
qh = (qh & 0x03).reshape((n_blocks, -1, 32))
|
||||
q = (ql | (qh << 4)).to(torch.int8) - 32
|
||||
q = q.reshape((n_blocks, QK_K // 16, -1))
|
||||
|
||||
return (d * q).reshape((n_blocks, QK_K))
|
||||
|
||||
|
||||
def dequantize_blocks_Q5_K(blocks, block_size, type_size, dtype=None):
|
||||
n_blocks = blocks.shape[0]
|
||||
|
||||
d, dmin, scales, qh, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE, QK_K // 8)
|
||||
|
||||
d = d.view(torch.float16).to(dtype)
|
||||
dmin = dmin.view(torch.float16).to(dtype)
|
||||
|
||||
sc, m = get_scale_min(scales)
|
||||
|
||||
d = (d * sc).reshape((n_blocks, -1, 1))
|
||||
dm = (dmin * m).reshape((n_blocks, -1, 1))
|
||||
|
||||
ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
|
||||
(1, 1, 2, 1)
|
||||
)
|
||||
qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.arange(0, 8, device=d.device, dtype=torch.uint8).reshape(
|
||||
(1, 1, 8, 1)
|
||||
)
|
||||
ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
|
||||
qh = (qh & 0x01).reshape((n_blocks, -1, 32))
|
||||
q = ql | (qh << 4)
|
||||
|
||||
return (d * q - dm).reshape((n_blocks, QK_K))
|
||||
|
||||
|
||||
def dequantize_blocks_Q4_K(blocks, block_size, type_size, dtype=None):
|
||||
n_blocks = blocks.shape[0]
|
||||
|
||||
d, dmin, scales, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE)
|
||||
d = d.view(torch.float16).to(dtype)
|
||||
dmin = dmin.view(torch.float16).to(dtype)
|
||||
|
||||
sc, m = get_scale_min(scales)
|
||||
|
||||
d = (d * sc).reshape((n_blocks, -1, 1))
|
||||
dm = (dmin * m).reshape((n_blocks, -1, 1))
|
||||
|
||||
qs = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
|
||||
(1, 1, 2, 1)
|
||||
)
|
||||
qs = (qs & 0x0F).reshape((n_blocks, -1, 32))
|
||||
|
||||
return (d * qs - dm).reshape((n_blocks, QK_K))
|
||||
|
||||
|
||||
def dequantize_blocks_Q3_K(blocks, block_size, type_size, dtype=None):
|
||||
n_blocks = blocks.shape[0]
|
||||
|
||||
hmask, qs, scales, d = split_block_dims(blocks, QK_K // 8, QK_K // 4, 12)
|
||||
d = d.view(torch.float16).to(dtype)
|
||||
|
||||
lscales, hscales = scales[:, :8], scales[:, 8:]
|
||||
lscales = lscales.reshape((n_blocks, 1, 8)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
|
||||
(1, 2, 1)
|
||||
)
|
||||
lscales = lscales.reshape((n_blocks, 16))
|
||||
hscales = hscales.reshape((n_blocks, 1, 4)) >> torch.tensor(
|
||||
[0, 2, 4, 6], device=d.device, dtype=torch.uint8
|
||||
).reshape((1, 4, 1))
|
||||
hscales = hscales.reshape((n_blocks, 16))
|
||||
scales = (lscales & 0x0F) | ((hscales & 0x03) << 4)
|
||||
scales = scales.to(torch.int8) - 32
|
||||
|
||||
dl = (d * scales).reshape((n_blocks, 16, 1))
|
||||
|
||||
ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape(
|
||||
(1, 1, 4, 1)
|
||||
)
|
||||
qh = hmask.reshape(n_blocks, -1, 1, 32) >> torch.arange(0, 8, device=d.device, dtype=torch.uint8).reshape(
|
||||
(1, 1, 8, 1)
|
||||
)
|
||||
ql = ql.reshape((n_blocks, 16, QK_K // 16)) & 3
|
||||
qh = (qh.reshape((n_blocks, 16, QK_K // 16)) & 1) ^ 1
|
||||
q = ql.to(torch.int8) - (qh << 2).to(torch.int8)
|
||||
|
||||
return (dl * q).reshape((n_blocks, QK_K))
|
||||
|
||||
|
||||
def dequantize_blocks_Q2_K(blocks, block_size, type_size, dtype=None):
|
||||
n_blocks = blocks.shape[0]
|
||||
|
||||
scales, qs, d, dmin = split_block_dims(blocks, QK_K // 16, QK_K // 4, 2)
|
||||
d = d.view(torch.float16).to(dtype)
|
||||
dmin = dmin.view(torch.float16).to(dtype)
|
||||
|
||||
# (n_blocks, 16, 1)
|
||||
dl = (d * (scales & 0xF)).reshape((n_blocks, QK_K // 16, 1))
|
||||
ml = (dmin * (scales >> 4)).reshape((n_blocks, QK_K // 16, 1))
|
||||
|
||||
shift = torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 1, 4, 1))
|
||||
|
||||
qs = (qs.reshape((n_blocks, -1, 1, 32)) >> shift) & 3
|
||||
qs = qs.reshape((n_blocks, QK_K // 16, 16))
|
||||
qs = dl * qs - ml
|
||||
|
||||
return qs.reshape((n_blocks, -1))
|
||||
|
||||
|
||||
def dequantize_blocks_BF16(blocks, block_size, type_size, dtype=None):
|
||||
return (blocks.view(torch.int16).to(torch.int32) << 16).view(torch.float32)
|
||||
|
||||
|
||||
GGML_QUANT_SIZES = gguf.GGML_QUANT_SIZES
|
||||
dequantize_functions = {
|
||||
gguf.GGMLQuantizationType.BF16: dequantize_blocks_BF16,
|
||||
gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0,
|
||||
gguf.GGMLQuantizationType.Q5_1: dequantize_blocks_Q5_1,
|
||||
gguf.GGMLQuantizationType.Q5_0: dequantize_blocks_Q5_0,
|
||||
gguf.GGMLQuantizationType.Q4_1: dequantize_blocks_Q4_1,
|
||||
gguf.GGMLQuantizationType.Q4_0: dequantize_blocks_Q4_0,
|
||||
gguf.GGMLQuantizationType.Q6_K: dequantize_blocks_Q6_K,
|
||||
gguf.GGMLQuantizationType.Q5_K: dequantize_blocks_Q5_K,
|
||||
gguf.GGMLQuantizationType.Q4_K: dequantize_blocks_Q4_K,
|
||||
gguf.GGMLQuantizationType.Q3_K: dequantize_blocks_Q3_K,
|
||||
gguf.GGMLQuantizationType.Q2_K: dequantize_blocks_Q2_K,
|
||||
}
|
||||
SUPPORTED_GGUF_QUANT_TYPES = list(dequantize_functions.keys())
|
||||
|
||||
|
||||
def _quant_shape_from_byte_shape(shape, type_size, block_size):
|
||||
return (*shape[:-1], shape[-1] // type_size * block_size)
|
||||
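A hedged worked example of the helper above, using Q8_0 (whose `GGML_QUANT_SIZES` entry is `(32, 34)`: each block of 32 weights is stored as 32 int8 values plus a 2-byte fp16 scale). The byte shape is an assumed value.

```python
import gguf

# For Q8_0: block_size=32 logical weights, type_size=34 bytes per block.
block_size, type_size = gguf.GGML_QUANT_SIZES[gguf.GGMLQuantizationType.Q8_0]

byte_shape = (3072, 1088)  # shape of the raw uint8 tensor as stored in the GGUF file
logical_shape = (*byte_shape[:-1], byte_shape[-1] // type_size * block_size)
assert logical_shape == (3072, 1024)  # shape of the dequantized weight
```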
|
||||
|
||||
def dequantize_gguf_tensor(tensor):
|
||||
if not hasattr(tensor, "quant_type"):
|
||||
return tensor
|
||||
|
||||
quant_type = tensor.quant_type
|
||||
dequant_fn = dequantize_functions[quant_type]
|
||||
|
||||
block_size, type_size = GGML_QUANT_SIZES[quant_type]
|
||||
|
||||
tensor = tensor.view(torch.uint8)
|
||||
shape = _quant_shape_from_byte_shape(tensor.shape, type_size, block_size)
|
||||
|
||||
n_blocks = tensor.numel() // type_size
|
||||
blocks = tensor.reshape((n_blocks, type_size))
|
||||
|
||||
dequant = dequant_fn(blocks, block_size, type_size)
|
||||
dequant = dequant.reshape(shape)
|
||||
|
||||
return dequant.as_tensor()
|
||||
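A hedged, self-contained sketch of `dequantize_gguf_tensor` on a dummy Q8_0 tensor (all-zero bytes decode to all-zero weights; real GGUF tensors come from `load_gguf_checkpoint`):

```python
import gguf
import torch

from diffusers.quantizers.gguf.utils import GGUFParameter, dequantize_gguf_tensor

# Q8_0 stores each block of 32 weights as 34 bytes (2-byte fp16 scale + 32 int8 values).
raw = torch.zeros(64, 34, dtype=torch.uint8)  # 64 blocks of dummy quantized data
w = GGUFParameter(raw, quant_type=gguf.GGMLQuantizationType.Q8_0)

w_fp = dequantize_gguf_tensor(w)  # plain tensor in the logical weight shape
print(w_fp.shape, w_fp.dtype)     # torch.Size([64, 32]) torch.float16
```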
|
||||
|
||||
class GGUFParameter(torch.nn.Parameter):
|
||||
def __new__(cls, data, requires_grad=False, quant_type=None):
|
||||
data = data if data is not None else torch.empty(0)
|
||||
self = torch.Tensor._make_subclass(cls, data, requires_grad)
|
||||
self.quant_type = quant_type
|
||||
|
||||
return self
|
||||
|
||||
def as_tensor(self):
|
||||
return torch.Tensor._make_subclass(torch.Tensor, self, self.requires_grad)
|
||||
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
result = super().__torch_function__(func, types, args, kwargs)
|
||||
|
||||
# When converting from original format checkpoints we often use splits, cats etc on tensors
|
||||
# this method ensures that the returned tensor type from those operations remains GGUFParameter
|
||||
# so that we preserve quant_type information
|
||||
quant_type = None
|
||||
for arg in args:
|
||||
if isinstance(arg, list) and arg and isinstance(arg[0], GGUFParameter):
|
||||
quant_type = arg[0].quant_type
|
||||
break
|
||||
if isinstance(arg, GGUFParameter):
|
||||
quant_type = arg.quant_type
|
||||
break
|
||||
if isinstance(result, torch.Tensor):
|
||||
return cls(result, quant_type=quant_type)
|
||||
# Handle tuples and lists
|
||||
elif isinstance(result, (tuple, list)):
|
||||
# Preserve the original type (tuple or list)
|
||||
wrapped = [cls(x, quant_type=quant_type) if isinstance(x, torch.Tensor) else x for x in result]
|
||||
return type(result)(wrapped)
|
||||
else:
|
||||
return result
|
||||
|
||||
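A hedged sketch of the wrapping behavior described in the comments above: tensor ops on `GGUFParameter` inputs return `GGUFParameter` outputs with `quant_type` preserved (the dummy shapes and values are assumptions):

```python
import gguf
import torch

from diffusers.quantizers.gguf.utils import GGUFParameter

q8 = gguf.GGMLQuantizationType.Q8_0
a = GGUFParameter(torch.zeros(2, 34, dtype=torch.uint8), quant_type=q8)
b = GGUFParameter(torch.ones(2, 34, dtype=torch.uint8), quant_type=q8)

c = torch.cat([a, b])  # routed through __torch_function__
assert isinstance(c, GGUFParameter) and c.quant_type == q8
```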
|
||||
class GGUFLinear(nn.Linear):
|
||||
def __init__(
|
||||
self,
|
||||
in_features,
|
||||
out_features,
|
||||
bias=False,
|
||||
compute_dtype=None,
|
||||
device=None,
|
||||
) -> None:
|
||||
super().__init__(in_features, out_features, bias, device)
|
||||
self.compute_dtype = compute_dtype
|
||||
|
||||
def forward(self, inputs):
|
||||
weight = dequantize_gguf_tensor(self.weight)
|
||||
weight = weight.to(self.compute_dtype)
|
||||
bias = self.bias.to(self.compute_dtype)
|
||||
|
||||
output = torch.nn.functional.linear(inputs, weight, bias)
|
||||
return output
|
||||
@@ -43,6 +43,7 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
class QuantizationMethod(str, Enum):
|
||||
BITS_AND_BYTES = "bitsandbytes"
|
||||
GGUF = "gguf"
|
||||
TORCHAO = "torchao"
|
||||
|
||||
|
||||
@@ -394,6 +395,29 @@ class BitsAndBytesConfig(QuantizationConfigMixin):
|
||||
return serializable_config_dict
|
||||
|
||||
|
||||
@dataclass
|
||||
class GGUFQuantizationConfig(QuantizationConfigMixin):
|
||||
"""This is a config class for GGUF Quantization techniques.
|
||||
|
||||
Args:
|
||||
compute_dtype (`torch.dtype`, defaults to `torch.float32`):
    This sets the computational type, which might be different from the input type. For example, inputs might be
    fp32, but computation can be set to bf16 for speedups.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, compute_dtype: Optional["torch.dtype"] = None):
|
||||
self.quant_method = QuantizationMethod.GGUF
|
||||
self.compute_dtype = compute_dtype
|
||||
self.pre_quantized = True
|
||||
|
||||
# TODO: (Dhruv) Add this as an init argument when we can support loading unquantized checkpoints.
|
||||
self.modules_to_not_convert = None
|
||||
|
||||
if self.compute_dtype is None:
|
||||
self.compute_dtype = torch.float32
|
||||
|
||||
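A hedged illustration of the defaults set above:

```python
import torch

from diffusers import GGUFQuantizationConfig

cfg = GGUFQuantizationConfig()  # no compute_dtype passed
assert cfg.compute_dtype == torch.float32
assert cfg.pre_quantized is True  # only pre-quantized GGUF checkpoints are supported for now
```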
|
||||
@dataclass
|
||||
class TorchAoConfig(QuantizationConfigMixin):
|
||||
"""This is a config class for torchao quantization/sparsity techniques.
|
||||
|
||||
@@ -23,6 +23,7 @@ from .constants import (
|
||||
DEPRECATED_REVISION_ARGS,
|
||||
DIFFUSERS_DYNAMIC_MODULE_NAME,
|
||||
FLAX_WEIGHTS_NAME,
|
||||
GGUF_FILE_EXTENSION,
|
||||
HF_MODULES_CACHE,
|
||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
||||
MIN_PEFT_VERSION,
|
||||
@@ -66,6 +67,8 @@ from .import_utils import (
|
||||
is_bs4_available,
|
||||
is_flax_available,
|
||||
is_ftfy_available,
|
||||
is_gguf_available,
|
||||
is_gguf_version,
|
||||
is_google_colab,
|
||||
is_inflect_available,
|
||||
is_invisible_watermark_available,
|
||||
|
||||
@@ -34,6 +34,7 @@ ONNX_WEIGHTS_NAME = "model.onnx"
|
||||
SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
|
||||
SAFE_WEIGHTS_INDEX_NAME = "diffusion_pytorch_model.safetensors.index.json"
|
||||
SAFETENSORS_FILE_EXTENSION = "safetensors"
|
||||
GGUF_FILE_EXTENSION = "gguf"
|
||||
ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
|
||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
|
||||
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
|
||||
|
||||
@@ -339,6 +339,14 @@ if _imageio_available:
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_imageio_available = False
|
||||
|
||||
_is_gguf_available = importlib.util.find_spec("gguf") is not None
|
||||
if _is_gguf_available:
|
||||
try:
|
||||
_gguf_version = importlib_metadata.version("gguf")
|
||||
logger.debug(f"Successfully import gguf version {_gguf_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_is_gguf_available = False
|
||||
|
||||
|
||||
_is_torchao_available = importlib.util.find_spec("torchao") is not None
|
||||
if _is_torchao_available:
|
||||
@@ -469,6 +477,10 @@ def is_imageio_available():
|
||||
return _imageio_available
|
||||
|
||||
|
||||
def is_gguf_available():
|
||||
return _is_gguf_available
|
||||
|
||||
|
||||
def is_torchao_available():
|
||||
return _is_torchao_available
|
||||
|
||||
@@ -607,8 +619,13 @@ IMAGEIO_IMPORT_ERROR = """
|
||||
"""
|
||||
|
||||
# docstyle-ignore
|
||||
GGUF_IMPORT_ERROR = """
|
||||
{0} requires the gguf library but it was not found in your environment. You can install it with pip: `pip install gguf`
|
||||
"""
|
||||
|
||||
TORCHAO_IMPORT_ERROR = """
|
||||
{0} requires the torchao library but it was not found in your environment. You can install it with pip: `pip install torchao`
|
||||
{0} requires the torchao library but it was not found in your environment. You can install it with pip: `pip install
|
||||
torchao`
|
||||
"""
|
||||
|
||||
BACKENDS_MAPPING = OrderedDict(
|
||||
@@ -636,6 +653,7 @@ BACKENDS_MAPPING = OrderedDict(
|
||||
("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)),
|
||||
("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)),
|
||||
("imageio", (is_imageio_available, IMAGEIO_IMPORT_ERROR)),
|
||||
("gguf", (is_gguf_available, GGUF_IMPORT_ERROR)),
|
||||
("torchao", (is_torchao_available, TORCHAO_IMPORT_ERROR)),
|
||||
]
|
||||
)
|
||||
@@ -793,6 +811,21 @@ def is_bitsandbytes_version(operation: str, version: str):
|
||||
return compare_versions(parse(_bitsandbytes_version), operation, version)
|
||||
|
||||
|
||||
def is_gguf_version(operation: str, version: str):
|
||||
"""
|
||||
Compares the current gguf version to a given reference with an operation.
|
||||
|
||||
Args:
|
||||
operation (`str`):
|
||||
A string representation of an operator, such as `">"` or `"<="`
|
||||
version (`str`):
|
||||
A version string
|
||||
"""
|
||||
if not _is_gguf_available:
|
||||
return False
|
||||
return compare_versions(parse(_gguf_version), operation, version)
|
||||
|
||||
|
||||
def is_k_diffusion_version(operation: str, version: str):
|
||||
"""
|
||||
Compares the current k-diffusion version to a given reference with an operation.
|
||||
|
||||
@@ -32,6 +32,7 @@ from .import_utils import (
|
||||
is_bitsandbytes_available,
|
||||
is_compel_available,
|
||||
is_flax_available,
|
||||
is_gguf_available,
|
||||
is_note_seq_available,
|
||||
is_onnx_available,
|
||||
is_opencv_available,
|
||||
@@ -477,6 +478,18 @@ def require_bitsandbytes_version_greater(bnb_version):
|
||||
return decorator
|
||||
|
||||
|
||||
def require_gguf_version_greater_or_equal(gguf_version):
|
||||
def decorator(test_case):
|
||||
correct_gguf_version = is_gguf_available() and version.parse(
|
||||
version.parse(importlib.metadata.version("gguf")).base_version
|
||||
) >= version.parse(gguf_version)
|
||||
return unittest.skipUnless(
|
||||
correct_gguf_version, f"Test requires gguf with the version greater than {gguf_version}."
|
||||
)(test_case)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def require_torchao_version_greater(torchao_version):
|
||||
def decorator(test_case):
|
||||
correct_torchao_version = is_torchao_available() and version.parse(
|
||||
|
||||
tests/quantization/gguf/test_gguf.py (new file, 379 lines)
@@ -0,0 +1,379 @@
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from diffusers import (
|
||||
FluxPipeline,
|
||||
FluxTransformer2DModel,
|
||||
GGUFQuantizationConfig,
|
||||
SD3Transformer2DModel,
|
||||
StableDiffusion3Pipeline,
|
||||
)
|
||||
from diffusers.utils.testing_utils import (
|
||||
is_gguf_available,
|
||||
nightly,
|
||||
numpy_cosine_similarity_distance,
|
||||
require_accelerate,
|
||||
require_big_gpu_with_torch_cuda,
|
||||
require_gguf_version_greater_or_equal,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
if is_gguf_available():
|
||||
from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter
|
||||
|
||||
|
||||
@nightly
|
||||
@require_big_gpu_with_torch_cuda
|
||||
@require_accelerate
|
||||
@require_gguf_version_greater_or_equal("0.10.0")
|
||||
class GGUFSingleFileTesterMixin:
|
||||
ckpt_path = None
|
||||
model_cls = None
|
||||
torch_dtype = torch.bfloat16
|
||||
expected_memory_use_in_gb = 5
|
||||
|
||||
def test_gguf_parameters(self):
|
||||
quant_storage_type = torch.uint8
|
||||
quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
|
||||
model = self.model_cls.from_single_file(self.ckpt_path, quantization_config=quantization_config)
|
||||
|
||||
for param_name, param in model.named_parameters():
|
||||
if isinstance(param, GGUFParameter):
|
||||
assert hasattr(param, "quant_type")
|
||||
assert param.dtype == quant_storage_type
|
||||
|
||||
def test_gguf_linear_layers(self):
|
||||
quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
|
||||
model = self.model_cls.from_single_file(self.ckpt_path, quantization_config=quantization_config)
|
||||
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, torch.nn.Linear) and hasattr(module.weight, "quant_type"):
|
||||
assert module.weight.dtype == torch.uint8
|
||||
assert module.bias.dtype == torch.float32
|
||||
|
||||
def test_gguf_memory_usage(self):
|
||||
quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
|
||||
|
||||
model = self.model_cls.from_single_file(
|
||||
self.ckpt_path, quantization_config=quantization_config, torch_dtype=self.torch_dtype
|
||||
)
|
||||
model.to("cuda")
|
||||
assert (model.get_memory_footprint() / 1024**3) < self.expected_memory_use_in_gb
|
||||
inputs = self.get_dummy_inputs()
|
||||
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.empty_cache()
|
||||
with torch.no_grad():
|
||||
model(**inputs)
|
||||
max_memory = torch.cuda.max_memory_allocated()
|
||||
assert (max_memory / 1024**3) < self.expected_memory_use_in_gb
|
||||
|
||||
def test_keep_modules_in_fp32(self):
|
||||
r"""
|
||||
A simple test to check that the modules listed in `_keep_in_fp32_modules` are kept in fp32.
It also ensures that inference works.
|
||||
"""
|
||||
_keep_in_fp32_modules = self.model_cls._keep_in_fp32_modules
|
||||
self.model_cls._keep_in_fp32_modules = ["proj_out"]
|
||||
|
||||
quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
|
||||
model = self.model_cls.from_single_file(self.ckpt_path, quantization_config=quantization_config)
|
||||
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
if name in model._keep_in_fp32_modules:
|
||||
assert module.weight.dtype == torch.float32
|
||||
self.model_cls._keep_in_fp32_modules = _keep_in_fp32_modules
|
||||
|
||||
def test_dtype_assignment(self):
|
||||
quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
|
||||
model = self.model_cls.from_single_file(self.ckpt_path, quantization_config=quantization_config)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# Tries with a `dtype`
|
||||
model.to(torch.float16)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# Tries with a `device` and `dtype`
|
||||
model.to(device="cuda:0", dtype=torch.float16)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# Tries with a cast
|
||||
model.float()
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# Tries with a cast
|
||||
model.half()
|
||||
|
||||
# This should work
|
||||
model.to("cuda")
|
||||
|
||||
def test_dequantize_model(self):
|
||||
quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
|
||||
model = self.model_cls.from_single_file(self.ckpt_path, quantization_config=quantization_config)
|
||||
model.dequantize()
|
||||
|
||||
def _check_for_gguf_linear(model):
|
||||
has_children = list(model.children())
|
||||
if not has_children:
|
||||
return
|
||||
|
||||
for name, module in model.named_children():
|
||||
if isinstance(module, nn.Linear):
|
||||
assert not isinstance(module, GGUFLinear), f"{name} is still GGUFLinear"
|
||||
assert not isinstance(module.weight, GGUFParameter), f"{name} weight is still GGUFParameter"
|
||||
|
||||
for name, module in model.named_children():
|
||||
_check_for_gguf_linear(module)
|
||||
|
||||
|
||||
class FluxGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
    ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
    torch_dtype = torch.bfloat16
    model_cls = FluxTransformer2DModel
    expected_memory_use_in_gb = 5

    def setUp(self):
        gc.collect()
        torch.cuda.empty_cache()

    def tearDown(self):
        gc.collect()
        torch.cuda.empty_cache()

    def get_dummy_inputs(self):
        return {
            "hidden_states": torch.randn((1, 4096, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
                torch_device, self.torch_dtype
            ),
            "encoder_hidden_states": torch.randn(
                (1, 512, 4096),
                generator=torch.Generator("cpu").manual_seed(0),
            ).to(torch_device, self.torch_dtype),
            "pooled_projections": torch.randn(
                (1, 768),
                generator=torch.Generator("cpu").manual_seed(0),
            ).to(torch_device, self.torch_dtype),
            "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
            "img_ids": torch.randn((4096, 3), generator=torch.Generator("cpu").manual_seed(0)).to(
                torch_device, self.torch_dtype
            ),
            "txt_ids": torch.randn((512, 3), generator=torch.Generator("cpu").manual_seed(0)).to(
                torch_device, self.torch_dtype
            ),
            "guidance": torch.tensor([3.5]).to(torch_device, self.torch_dtype),
        }

    def test_pipeline_inference(self):
        quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
        transformer = self.model_cls.from_single_file(
            self.ckpt_path, quantization_config=quantization_config, torch_dtype=self.torch_dtype
        )
        pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=self.torch_dtype
        )
        pipe.enable_model_cpu_offload()

        prompt = "a cat holding a sign that says hello"
        output = pipe(
            prompt=prompt, num_inference_steps=2, generator=torch.Generator("cpu").manual_seed(0), output_type="np"
        ).images[0]
        output_slice = output[:3, :3, :].flatten()
        expected_slice = np.array(
            [
                0.47265625, 0.43359375, 0.359375, 0.47070312, 0.421875, 0.34375,
                0.46875, 0.421875, 0.34765625, 0.46484375, 0.421875, 0.34179688,
                0.47070312, 0.42578125, 0.34570312, 0.46875, 0.42578125, 0.3515625,
                0.45507812, 0.4140625, 0.33984375, 0.4609375, 0.41796875, 0.34375,
                0.45898438, 0.41796875, 0.34375,
            ]
        )
        max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
        assert max_diff < 1e-4


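# The SD3.5 classes below repeat the same pipeline-level check for SD3Transformer2DModel,
# using the Q4_0 (large) and Q3_K_M (medium) GGUF variants.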
class SD35LargeGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
    ckpt_path = "https://huggingface.co/city96/stable-diffusion-3.5-large-gguf/blob/main/sd3.5_large-Q4_0.gguf"
    torch_dtype = torch.bfloat16
    model_cls = SD3Transformer2DModel
    expected_memory_use_in_gb = 5

    def setUp(self):
        gc.collect()
        torch.cuda.empty_cache()

    def tearDown(self):
        gc.collect()
        torch.cuda.empty_cache()

    def get_dummy_inputs(self):
        return {
            "hidden_states": torch.randn((1, 16, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
                torch_device, self.torch_dtype
            ),
            "encoder_hidden_states": torch.randn(
                (1, 512, 4096),
                generator=torch.Generator("cpu").manual_seed(0),
            ).to(torch_device, self.torch_dtype),
            "pooled_projections": torch.randn(
                (1, 2048),
                generator=torch.Generator("cpu").manual_seed(0),
            ).to(torch_device, self.torch_dtype),
            "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
        }

    def test_pipeline_inference(self):
        quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
        transformer = self.model_cls.from_single_file(
            self.ckpt_path, quantization_config=quantization_config, torch_dtype=self.torch_dtype
        )
        pipe = StableDiffusion3Pipeline.from_pretrained(
            "stabilityai/stable-diffusion-3.5-large", transformer=transformer, torch_dtype=self.torch_dtype
        )
        pipe.enable_model_cpu_offload()

        prompt = "a cat holding a sign that says hello"
        output = pipe(
            prompt=prompt, num_inference_steps=2, generator=torch.Generator("cpu").manual_seed(0), output_type="np"
        ).images[0]
        output_slice = output[:3, :3, :].flatten()
        expected_slice = np.array(
            [
                0.17578125, 0.27539062, 0.27734375, 0.11914062, 0.26953125, 0.25390625,
                0.109375, 0.25390625, 0.25, 0.15039062, 0.26171875, 0.28515625,
                0.13671875, 0.27734375, 0.28515625, 0.12109375, 0.26757812, 0.265625,
                0.16210938, 0.29882812, 0.28515625, 0.15625, 0.30664062, 0.27734375,
                0.14648438, 0.29296875, 0.26953125,
            ]
        )
        max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
        assert max_diff < 1e-4


class SD35MediumGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
    ckpt_path = "https://huggingface.co/city96/stable-diffusion-3.5-medium-gguf/blob/main/sd3.5_medium-Q3_K_M.gguf"
    torch_dtype = torch.bfloat16
    model_cls = SD3Transformer2DModel
    expected_memory_use_in_gb = 2

    def setUp(self):
        gc.collect()
        torch.cuda.empty_cache()

    def tearDown(self):
        gc.collect()
        torch.cuda.empty_cache()

    def get_dummy_inputs(self):
        return {
            "hidden_states": torch.randn((1, 16, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
                torch_device, self.torch_dtype
            ),
            "encoder_hidden_states": torch.randn(
                (1, 512, 4096),
                generator=torch.Generator("cpu").manual_seed(0),
            ).to(torch_device, self.torch_dtype),
            "pooled_projections": torch.randn(
                (1, 2048),
                generator=torch.Generator("cpu").manual_seed(0),
            ).to(torch_device, self.torch_dtype),
            "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
        }

    def test_pipeline_inference(self):
        quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
        transformer = self.model_cls.from_single_file(
            self.ckpt_path, quantization_config=quantization_config, torch_dtype=self.torch_dtype
        )
        pipe = StableDiffusion3Pipeline.from_pretrained(
            "stabilityai/stable-diffusion-3.5-medium", transformer=transformer, torch_dtype=self.torch_dtype
        )
        pipe.enable_model_cpu_offload()

        prompt = "a cat holding a sign that says hello"
        output = pipe(
            prompt=prompt, num_inference_steps=2, generator=torch.Generator("cpu").manual_seed(0), output_type="np"
        ).images[0]
        output_slice = output[:3, :3, :].flatten()
        expected_slice = np.array(
            [
                0.625, 0.6171875, 0.609375, 0.65625, 0.65234375, 0.640625,
                0.6484375, 0.640625, 0.625, 0.6484375, 0.63671875, 0.6484375,
                0.66796875, 0.65625, 0.65234375, 0.6640625, 0.6484375, 0.6328125,
                0.6640625, 0.6484375, 0.640625, 0.67578125, 0.66015625, 0.62109375,
                0.671875, 0.65625, 0.62109375,
            ]
        )
        max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
        assert max_diff < 1e-4