1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Single File] Add GGUF support (#9964)

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* Update src/diffusers/quantizers/gguf/utils.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* Update docs/source/en/quantization/gguf.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* update

* update

* update

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
This commit is contained in:
Dhruv Nair
2024-12-17 16:09:37 +05:30
committed by GitHub
parent f9d5a9324d
commit e24941b2a7
22 changed files with 1321 additions and 21 deletions

View File

@@ -357,6 +357,8 @@ jobs:
config:
- backend: "bitsandbytes"
test_location: "bnb"
- backend: "gguf"
test_location: "gguf"
runs-on:
group: aws-g6e-xlarge-plus
container:

View File

@@ -157,6 +157,8 @@
title: Getting Started
- local: quantization/bitsandbytes
title: bitsandbytes
- local: quantization/gguf
title: gguf
- local: quantization/torchao
title: torchao
title: Quantization Methods

View File

@@ -28,6 +28,9 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui
[[autodoc]] BitsAndBytesConfig
## GGUFQuantizationConfig
[[autodoc]] GGUFQuantizationConfig
## TorchAoConfig
[[autodoc]] TorchAoConfig

View File

@@ -0,0 +1,70 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# GGUF
The GGUF file format is typically used to store models for inference with [GGML](https://github.com/ggerganov/ggml) and supports a variety of block wise quantization options. Diffusers supports loading checkpoints prequantized and saved in the GGUF format via `from_single_file` loading with Model classes. Loading GGUF checkpoints via Pipelines is currently not supported.
The following example will load the [FLUX.1 DEV](https://huggingface.co/black-forest-labs/FLUX.1-dev) transformer model using the GGUF Q2_K quantization variant.
Before starting please install gguf in your environment
```shell
pip install -U gguf
```
Since GGUF is a single file format, use [`~FromSingleFileMixin.from_single_file`] to load the model and pass in the [`GGUFQuantizationConfig`].
When using GGUF checkpoints, the quantized weights remain in a low memory `dtype`(typically `torch.unint8`) and are dynamically dequantized and cast to the configured `compute_dtype` during each module's forward pass through the model. The `GGUFQuantizationConfig` allows you to set the `compute_dtype`.
The functions used for dynamic dequantizatation are based on the great work done by [city96](https://github.com/city96/ComfyUI-GGUF), who created the Pytorch ports of the original (`numpy`)[https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/gguf/quants.py] implementation by [compilade](https://github.com/compilade).
```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig
ckpt_path = (
"https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
)
transformer = FluxTransformer2DModel.from_single_file(
ckpt_path,
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
torch_dtype=torch.bfloat16,
)
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
transformer=transformer,
generator=torch.manual_seed(0),
torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()
prompt = "A cat holding a sign that says hello world"
image = pipe(prompt).images[0]
image.save("flux-gguf.png")
```
## Supported Quantization Types
- BF16
- Q4_0
- Q4_1
- Q5_0
- Q5_1
- Q8_0
- Q2_K
- Q3_K
- Q4_K
- Q5_K
- Q6_K

View File

@@ -17,7 +17,7 @@ Quantization techniques focus on representing data with less information while a
<Tip>
Interested in adding a new quantization method to Transformers? Refer to the [Contribute new quantization method guide](https://huggingface.co/docs/transformers/main/en/quantization/contribute) to learn more about adding a new quantization method.
Interested in adding a new quantization method to Diffusers? Refer to the [Contribute new quantization method guide](https://huggingface.co/docs/transformers/main/en/quantization/contribute) to learn more about adding a new quantization method.
</Tip>
@@ -32,4 +32,9 @@ If you are new to the quantization field, we recommend you to check out these be
## When to use what?
Diffusers supports [bitsandbytes](https://huggingface.co/docs/bitsandbytes/main/en/index) and [torchao](https://github.com/pytorch/ao). Refer to this [table](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) to help you determine which quantization backend to use.
Diffusers currently supports the following quantization methods.
- [BitsandBytes]()
- [TorchAO]()
- [GGUF]()
[This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.

View File

@@ -31,7 +31,7 @@ _import_structure = {
"loaders": ["FromOriginalModelMixin"],
"models": [],
"pipelines": [],
"quantizers.quantization_config": ["BitsAndBytesConfig", "TorchAoConfig"],
"quantizers.quantization_config": ["BitsAndBytesConfig", "GGUFQuantizationConfig", "TorchAoConfig"],
"schedulers": [],
"utils": [
"OptionalDependencyNotAvailable",
@@ -569,7 +569,7 @@ else:
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .configuration_utils import ConfigMixin
from .quantizers.quantization_config import BitsAndBytesConfig, TorchAoConfig
from .quantizers.quantization_config import BitsAndBytesConfig, GGUFQuantizationConfig, TorchAoConfig
try:
if not is_onnx_available():

View File

@@ -17,8 +17,10 @@ import re
from contextlib import nullcontext
from typing import Optional
import torch
from huggingface_hub.utils import validate_hf_hub_args
from ..quantizers import DiffusersAutoQuantizer
from ..utils import deprecate, is_accelerate_available, logging
from .single_file_utils import (
SingleFileComponentError,
@@ -214,6 +216,8 @@ class FromOriginalModelMixin:
subfolder = kwargs.pop("subfolder", None)
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
quantization_config = kwargs.pop("quantization_config", None)
device = kwargs.pop("device", None)
if isinstance(pretrained_model_link_or_path_or_dict, dict):
checkpoint = pretrained_model_link_or_path_or_dict
@@ -227,6 +231,12 @@ class FromOriginalModelMixin:
local_files_only=local_files_only,
revision=revision,
)
if quantization_config is not None:
hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
hf_quantizer.validate_environment()
else:
hf_quantizer = None
mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name]
@@ -309,8 +319,36 @@ class FromOriginalModelMixin:
with ctx():
model = cls.from_config(diffusers_model_config)
# Check if `_keep_in_fp32_modules` is not None
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
(torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
)
if use_keep_in_fp32_modules:
keep_in_fp32_modules = cls._keep_in_fp32_modules
if not isinstance(keep_in_fp32_modules, list):
keep_in_fp32_modules = [keep_in_fp32_modules]
else:
keep_in_fp32_modules = []
if hf_quantizer is not None:
hf_quantizer.preprocess_model(
model=model,
device_map=None,
state_dict=diffusers_format_checkpoint,
keep_in_fp32_modules=keep_in_fp32_modules,
)
if is_accelerate_available():
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
param_device = torch.device(device) if device else torch.device("cpu")
unexpected_keys = load_model_dict_into_meta(
model,
diffusers_format_checkpoint,
dtype=torch_dtype,
device=param_device,
hf_quantizer=hf_quantizer,
keep_in_fp32_modules=keep_in_fp32_modules,
)
else:
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
@@ -324,7 +362,11 @@ class FromOriginalModelMixin:
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)
if torch_dtype is not None:
if hf_quantizer is not None:
hf_quantizer.postprocess_model(model)
model.hf_quantizer = hf_quantizer
if torch_dtype is not None and hf_quantizer is None:
model.to(torch_dtype)
model.eval()

View File

@@ -81,8 +81,14 @@ CHECKPOINT_KEY_NAMES = {
"open_clip_sd3": "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight",
"stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
"stable_cascade_stage_c": "clip_txt_mapper.weight",
"sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
"sd35_large": "model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight",
"sd3": [
"joint_blocks.0.context_block.adaLN_modulation.1.bias",
"model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
],
"sd35_large": [
"joint_blocks.37.x_block.mlp.fc1.weight",
"model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight",
],
"animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe",
"animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias",
"animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
@@ -542,13 +548,20 @@ def infer_diffusers_model_type(checkpoint):
):
model_type = "stable_cascade_stage_b"
elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["sd3"]].shape[-1] == 9216:
if checkpoint["model.diffusion_model.pos_embed"].shape[1] == 36864:
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sd3"]) and any(
checkpoint[key].shape[-1] == 9216 if key in checkpoint else False for key in CHECKPOINT_KEY_NAMES["sd3"]
):
if "model.diffusion_model.pos_embed" in checkpoint:
key = "model.diffusion_model.pos_embed"
else:
key = "pos_embed"
if checkpoint[key].shape[1] == 36864:
model_type = "sd3"
elif checkpoint["model.diffusion_model.pos_embed"].shape[1] == 147456:
elif checkpoint[key].shape[1] == 147456:
model_type = "sd35_medium"
elif CHECKPOINT_KEY_NAMES["sd35_large"] in checkpoint:
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sd35_large"]):
model_type = "sd35_large"
elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint:

View File

@@ -17,6 +17,7 @@
import importlib
import inspect
import os
from array import array
from collections import OrderedDict
from pathlib import Path
from typing import List, Optional, Union
@@ -26,6 +27,7 @@ import torch
from huggingface_hub.utils import EntryNotFoundError
from ..utils import (
GGUF_FILE_EXTENSION,
SAFE_WEIGHTS_INDEX_NAME,
SAFETENSORS_FILE_EXTENSION,
WEIGHTS_INDEX_NAME,
@@ -33,6 +35,8 @@ from ..utils import (
_get_model_file,
deprecate,
is_accelerate_available,
is_gguf_available,
is_torch_available,
is_torch_version,
logging,
)
@@ -139,6 +143,8 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
file_extension = os.path.basename(checkpoint_file).split(".")[-1]
if file_extension == SAFETENSORS_FILE_EXTENSION:
return safetensors.torch.load_file(checkpoint_file, device="cpu")
elif file_extension == GGUF_FILE_EXTENSION:
return load_gguf_checkpoint(checkpoint_file)
else:
weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
return torch.load(
@@ -211,13 +217,14 @@ def load_model_dict_into_meta(
set_module_kwargs["dtype"] = dtype
# bnb params are flattened.
# gguf quants have a different shape based on the type of quantization applied
if empty_state_dict[param_name].shape != param.shape:
if (
is_quantized
and hf_quantizer.pre_quantized
and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
):
hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name].shape, param.shape)
hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name], param)
else:
model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
raise ValueError(
@@ -396,3 +403,78 @@ def _fetch_index_file_legacy(
index_file = None
return index_file
def _gguf_parse_value(_value, data_type):
if not isinstance(data_type, list):
data_type = [data_type]
if len(data_type) == 1:
data_type = data_type[0]
array_data_type = None
else:
if data_type[0] != 9:
raise ValueError("Received multiple types, therefore expected the first type to indicate an array.")
data_type, array_data_type = data_type
if data_type in [0, 1, 2, 3, 4, 5, 10, 11]:
_value = int(_value[0])
elif data_type in [6, 12]:
_value = float(_value[0])
elif data_type in [7]:
_value = bool(_value[0])
elif data_type in [8]:
_value = array("B", list(_value)).tobytes().decode()
elif data_type in [9]:
_value = _gguf_parse_value(_value, array_data_type)
return _value
def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
"""
Load a GGUF file and return a dictionary of parsed parameters containing tensors, the parsed tokenizer and config
attributes.
Args:
gguf_checkpoint_path (`str`):
The path the to GGUF file to load
return_tensors (`bool`, defaults to `True`):
Whether to read the tensors from the file and return them. Not doing so is faster and only loads the
metadata in memory.
"""
if is_gguf_available() and is_torch_available():
import gguf
from gguf import GGUFReader
from ..quantizers.gguf.utils import SUPPORTED_GGUF_QUANT_TYPES, GGUFParameter
else:
logger.error(
"Loading a GGUF checkpoint in PyTorch, requires both PyTorch and GGUF>=0.10.0 to be installed. Please see "
"https://pytorch.org/ and https://github.com/ggerganov/llama.cpp/tree/master/gguf-py for installation instructions."
)
raise ImportError("Please install torch and gguf>=0.10.0 to load a GGUF checkpoint in PyTorch.")
reader = GGUFReader(gguf_checkpoint_path)
parsed_parameters = {}
for tensor in reader.tensors:
name = tensor.name
quant_type = tensor.tensor_type
# if the tensor is a torch supported dtype do not use GGUFParameter
is_gguf_quant = quant_type not in [gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16]
if is_gguf_quant and quant_type not in SUPPORTED_GGUF_QUANT_TYPES:
_supported_quants_str = "\n".join([str(type) for type in SUPPORTED_GGUF_QUANT_TYPES])
raise ValueError(
(
f"{name} has a quantization type: {str(quant_type)} which is unsupported."
"\n\nCurrently the following quantization types are supported: \n\n"
f"{_supported_quants_str}"
"\n\nTo request support for this quantization type please open an issue here: https://github.com/huggingface/diffusers"
)
)
weights = torch.from_numpy(tensor.data.copy())
parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights
return parsed_parameters

View File

@@ -1038,14 +1038,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
dtype_present_in_args = True
break
# Checks if the model has been loaded in 4-bit or 8-bit with BNB
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
if getattr(self, "is_quantized", False):
if dtype_present_in_args:
raise ValueError(
"You cannot cast a bitsandbytes model in a new `dtype`. Make sure to load the model using `from_pretrained` using the"
" desired `dtype` by passing the correct `torch_dtype` argument."
"Casting a quantized model to a new `dtype` is unsupported. To set the dtype of unquantized layers, please "
"use the `torch_dtype` argument when loading the model using `from_pretrained` or `from_single_file`"
)
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
if getattr(self, "is_loaded_in_8bit", False):
raise ValueError(
"`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the"

View File

@@ -524,7 +524,6 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
)
else:
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for index_block, block in enumerate(self.single_transformer_blocks):

View File

@@ -15,23 +15,33 @@
Adapted from
https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/quantizers/auto.py
"""
import warnings
from typing import Dict, Optional, Union
from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
from .quantization_config import BitsAndBytesConfig, QuantizationConfigMixin, QuantizationMethod, TorchAoConfig
from .gguf import GGUFQuantizer
from .quantization_config import (
BitsAndBytesConfig,
GGUFQuantizationConfig,
QuantizationConfigMixin,
QuantizationMethod,
TorchAoConfig,
)
from .torchao import TorchAoHfQuantizer
AUTO_QUANTIZER_MAPPING = {
"bitsandbytes_4bit": BnB4BitDiffusersQuantizer,
"bitsandbytes_8bit": BnB8BitDiffusersQuantizer,
"gguf": GGUFQuantizer,
"torchao": TorchAoHfQuantizer,
}
AUTO_QUANTIZATION_CONFIG_MAPPING = {
"bitsandbytes_4bit": BitsAndBytesConfig,
"bitsandbytes_8bit": BitsAndBytesConfig,
"gguf": GGUFQuantizationConfig,
"torchao": TorchAoConfig,
}

View File

@@ -204,7 +204,10 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
module._parameters[tensor_name] = new_value
def check_quantized_param_shape(self, param_name, current_param_shape, loaded_param_shape):
def check_quantized_param_shape(self, param_name, current_param, loaded_param):
current_param_shape = current_param.shape
loaded_param_shape = loaded_param.shape
n = current_param_shape.numel()
inferred_shape = (n,) if "bias" in param_name else ((n + 1) // 2, 1)
if loaded_param_shape != inferred_shape:

View File

@@ -0,0 +1 @@
from .gguf_quantizer import GGUFQuantizer

View File

@@ -0,0 +1,159 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from ..base import DiffusersQuantizer
if TYPE_CHECKING:
from ...models.modeling_utils import ModelMixin
from ...utils import (
get_module_from_name,
is_accelerate_available,
is_accelerate_version,
is_gguf_available,
is_gguf_version,
is_torch_available,
logging,
)
if is_torch_available() and is_gguf_available():
import torch
from .utils import (
GGML_QUANT_SIZES,
GGUFParameter,
_dequantize_gguf_and_restore_linear,
_quant_shape_from_byte_shape,
_replace_with_gguf_linear,
)
logger = logging.get_logger(__name__)
class GGUFQuantizer(DiffusersQuantizer):
use_keep_in_fp32_modules = True
def __init__(self, quantization_config, **kwargs):
super().__init__(quantization_config, **kwargs)
self.compute_dtype = quantization_config.compute_dtype
self.pre_quantized = quantization_config.pre_quantized
self.modules_to_not_convert = quantization_config.modules_to_not_convert
if not isinstance(self.modules_to_not_convert, list):
self.modules_to_not_convert = [self.modules_to_not_convert]
def validate_environment(self, *args, **kwargs):
if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"):
raise ImportError(
"Loading GGUF Parameters requires `accelerate` installed in your enviroment: `pip install 'accelerate>=0.26.0'`"
)
if not is_gguf_available() or is_gguf_version("<", "0.10.0"):
raise ImportError(
"To load GGUF format files you must have `gguf` installed in your environment: `pip install gguf>=0.10.0`"
)
# Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.adjust_max_memory
def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
# need more space for buffers that are created during quantization
max_memory = {key: val * 0.90 for key, val in max_memory.items()}
return max_memory
def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
if target_dtype != torch.uint8:
logger.info(f"target_dtype {target_dtype} is replaced by `torch.uint8` for GGUF quantization")
return torch.uint8
def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
if torch_dtype is None:
torch_dtype = self.compute_dtype
return torch_dtype
def check_quantized_param_shape(self, param_name, current_param, loaded_param):
loaded_param_shape = loaded_param.shape
current_param_shape = current_param.shape
quant_type = loaded_param.quant_type
block_size, type_size = GGML_QUANT_SIZES[quant_type]
inferred_shape = _quant_shape_from_byte_shape(loaded_param_shape, type_size, block_size)
if inferred_shape != current_param_shape:
raise ValueError(
f"{param_name} has an expected quantized shape of: {inferred_shape}, but receieved shape: {loaded_param_shape}"
)
return True
def check_if_quantized_param(
self,
model: "ModelMixin",
param_value: Union["GGUFParameter", "torch.Tensor"],
param_name: str,
state_dict: Dict[str, Any],
**kwargs,
) -> bool:
if isinstance(param_value, GGUFParameter):
return True
return False
def create_quantized_param(
self,
model: "ModelMixin",
param_value: Union["GGUFParameter", "torch.Tensor"],
param_name: str,
target_device: "torch.device",
state_dict: Optional[Dict[str, Any]] = None,
unexpected_keys: Optional[List[str]] = None,
):
module, tensor_name = get_module_from_name(model, param_name)
if tensor_name not in module._parameters and tensor_name not in module._buffers:
raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
if tensor_name in module._parameters:
module._parameters[tensor_name] = param_value.to(target_device)
if tensor_name in module._buffers:
module._buffers[tensor_name] = param_value.to(target_device)
def _process_model_before_weight_loading(
self,
model: "ModelMixin",
device_map,
keep_in_fp32_modules: List[str] = [],
**kwargs,
):
state_dict = kwargs.get("state_dict", None)
self.modules_to_not_convert.extend(keep_in_fp32_modules)
self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]
_replace_with_gguf_linear(
model, self.compute_dtype, state_dict, modules_to_not_convert=self.modules_to_not_convert
)
def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs):
return model
@property
def is_serializable(self):
return False
@property
def is_trainable(self) -> bool:
return False
def _dequantize(self, model):
is_model_on_cpu = model.device.type == "cpu"
if is_model_on_cpu:
logger.info(
"Model was found to be on CPU (could happen as a result of `enable_model_cpu_offload()`). So, moving it to GPU. After dequantization, will move the model back to CPU again to preserve the previous device."
)
model.to(torch.cuda.current_device())
model = _dequantize_gguf_and_restore_linear(model, self.modules_to_not_convert)
if is_model_on_cpu:
model.to("cpu")
return model

View File

@@ -0,0 +1,456 @@
# Copyright 2024 The HuggingFace Team and City96. All rights reserved.
# #
# # Licensed under the Apache License, Version 2.0 (the "License");
# # you may not use this file except in compliance with the License.
# # You may obtain a copy of the License at
# #
# # http://www.apache.org/licenses/LICENSE-2.0
# #
# # Unless required by applicable law or agreed to in writing, software
# # distributed under the License is distributed on an "AS IS" BASIS,
# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# # See the License for the specific language governing permissions and
# # limitations under the License.
import inspect
from contextlib import nullcontext
import gguf
import torch
import torch.nn as nn
from ...utils import is_accelerate_available
if is_accelerate_available():
import accelerate
from accelerate import init_empty_weights
from accelerate.hooks import add_hook_to_module, remove_hook_from_module
# Copied from diffusers.quantizers.bitsandbytes.utils._create_accelerate_new_hook
def _create_accelerate_new_hook(old_hook):
r"""
Creates a new hook based on the old hook. Use it only if you know what you are doing ! This method is a copy of:
https://github.com/huggingface/peft/blob/748f7968f3a31ec06a1c2b0328993319ad9a150a/src/peft/utils/other.py#L245 with
some changes
"""
old_hook_cls = getattr(accelerate.hooks, old_hook.__class__.__name__)
old_hook_attr = old_hook.__dict__
filtered_old_hook_attr = {}
old_hook_init_signature = inspect.signature(old_hook_cls.__init__)
for k in old_hook_attr.keys():
if k in old_hook_init_signature.parameters:
filtered_old_hook_attr[k] = old_hook_attr[k]
new_hook = old_hook_cls(**filtered_old_hook_attr)
return new_hook
def _replace_with_gguf_linear(model, compute_dtype, state_dict, prefix="", modules_to_not_convert=[]):
def _should_convert_to_gguf(state_dict, prefix):
weight_key = prefix + "weight"
return weight_key in state_dict and isinstance(state_dict[weight_key], GGUFParameter)
has_children = list(model.children())
if not has_children:
return
for name, module in model.named_children():
module_prefix = prefix + name + "."
_replace_with_gguf_linear(module, compute_dtype, state_dict, module_prefix, modules_to_not_convert)
if (
isinstance(module, nn.Linear)
and _should_convert_to_gguf(state_dict, module_prefix)
and name not in modules_to_not_convert
):
ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
model._modules[name] = GGUFLinear(
module.in_features,
module.out_features,
module.bias is not None,
compute_dtype=compute_dtype,
)
model._modules[name].source_cls = type(module)
# Force requires_grad to False to avoid unexpected errors
model._modules[name].requires_grad_(False)
return model
def _dequantize_gguf_and_restore_linear(model, modules_to_not_convert=[]):
for name, module in model.named_children():
if isinstance(module, GGUFLinear) and name not in modules_to_not_convert:
device = module.weight.device
bias = getattr(module, "bias", None)
ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
new_module = nn.Linear(
module.in_features,
module.out_features,
module.bias is not None,
device=device,
)
new_module.weight = nn.Parameter(dequantize_gguf_tensor(module.weight))
if bias is not None:
new_module.bias = bias
# Create a new hook and attach it in case we use accelerate
if hasattr(module, "_hf_hook"):
old_hook = module._hf_hook
new_hook = _create_accelerate_new_hook(old_hook)
remove_hook_from_module(module)
add_hook_to_module(new_module, new_hook)
new_module.to(device)
model._modules[name] = new_module
has_children = list(module.children())
if has_children:
_dequantize_gguf_and_restore_linear(module, modules_to_not_convert)
return model
# dequantize operations based on torch ports of GGUF dequantize_functions
# from City96
# more info: https://github.com/city96/ComfyUI-GGUF/blob/main/dequant.py
QK_K = 256
K_SCALE_SIZE = 12
def to_uint32(x):
x = x.view(torch.uint8).to(torch.int32)
return (x[:, 0] | x[:, 1] << 8 | x[:, 2] << 16 | x[:, 3] << 24).unsqueeze(1)
def split_block_dims(blocks, *args):
n_max = blocks.shape[1]
dims = list(args) + [n_max - sum(args)]
return torch.split(blocks, dims, dim=1)
def get_scale_min(scales):
n_blocks = scales.shape[0]
scales = scales.view(torch.uint8)
scales = scales.reshape((n_blocks, 3, 4))
d, m, m_d = torch.split(scales, scales.shape[-2] // 3, dim=-2)
sc = torch.cat([d & 0x3F, (m_d & 0x0F) | ((d >> 2) & 0x30)], dim=-1)
min = torch.cat([m & 0x3F, (m_d >> 4) | ((m >> 2) & 0x30)], dim=-1)
return (sc.reshape((n_blocks, 8)), min.reshape((n_blocks, 8)))
def dequantize_blocks_Q8_0(blocks, block_size, type_size, dtype=None):
d, x = split_block_dims(blocks, 2)
d = d.view(torch.float16).to(dtype)
x = x.view(torch.int8)
return d * x
def dequantize_blocks_Q5_1(blocks, block_size, type_size, dtype=None):
n_blocks = blocks.shape[0]
d, m, qh, qs = split_block_dims(blocks, 2, 2, 4)
d = d.view(torch.float16).to(dtype)
m = m.view(torch.float16).to(dtype)
qh = to_uint32(qh)
qh = qh.reshape((n_blocks, 1)) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
ql = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
[0, 4], device=d.device, dtype=torch.uint8
).reshape(1, 1, 2, 1)
qh = (qh & 1).to(torch.uint8)
ql = (ql & 0x0F).reshape((n_blocks, -1))
qs = ql | (qh << 4)
return (d * qs) + m
def dequantize_blocks_Q5_0(blocks, block_size, type_size, dtype=None):
n_blocks = blocks.shape[0]
d, qh, qs = split_block_dims(blocks, 2, 4)
d = d.view(torch.float16).to(dtype)
qh = to_uint32(qh)
qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
ql = qs.reshape(n_blocks, -1, 1, block_size // 2) >> torch.tensor(
[0, 4], device=d.device, dtype=torch.uint8
).reshape(1, 1, 2, 1)
qh = (qh & 1).to(torch.uint8)
ql = (ql & 0x0F).reshape(n_blocks, -1)
qs = (ql | (qh << 4)).to(torch.int8) - 16
return d * qs
def dequantize_blocks_Q4_1(blocks, block_size, type_size, dtype=None):
n_blocks = blocks.shape[0]
d, m, qs = split_block_dims(blocks, 2, 2)
d = d.view(torch.float16).to(dtype)
m = m.view(torch.float16).to(dtype)
qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
[0, 4], device=d.device, dtype=torch.uint8
).reshape(1, 1, 2, 1)
qs = (qs & 0x0F).reshape(n_blocks, -1)
return (d * qs) + m
def dequantize_blocks_Q4_0(blocks, block_size, type_size, dtype=None):
n_blocks = blocks.shape[0]
d, qs = split_block_dims(blocks, 2)
d = d.view(torch.float16).to(dtype)
qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor(
[0, 4], device=d.device, dtype=torch.uint8
).reshape((1, 1, 2, 1))
qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8
return d * qs
def dequantize_blocks_Q6_K(blocks, block_size, type_size, dtype=None):
n_blocks = blocks.shape[0]
(
ql,
qh,
scales,
d,
) = split_block_dims(blocks, QK_K // 2, QK_K // 4, QK_K // 16)
scales = scales.view(torch.int8).to(dtype)
d = d.view(torch.float16).to(dtype)
d = (d * scales).reshape((n_blocks, QK_K // 16, 1))
ql = ql.reshape((n_blocks, -1, 1, 64)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
(1, 1, 2, 1)
)
ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape(
(1, 1, 4, 1)
)
qh = (qh & 0x03).reshape((n_blocks, -1, 32))
q = (ql | (qh << 4)).to(torch.int8) - 32
q = q.reshape((n_blocks, QK_K // 16, -1))
return (d * q).reshape((n_blocks, QK_K))
def dequantize_blocks_Q5_K(blocks, block_size, type_size, dtype=None):
n_blocks = blocks.shape[0]
d, dmin, scales, qh, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE, QK_K // 8)
d = d.view(torch.float16).to(dtype)
dmin = dmin.view(torch.float16).to(dtype)
sc, m = get_scale_min(scales)
d = (d * sc).reshape((n_blocks, -1, 1))
dm = (dmin * m).reshape((n_blocks, -1, 1))
ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
(1, 1, 2, 1)
)
qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.arange(0, 8, device=d.device, dtype=torch.uint8).reshape(
(1, 1, 8, 1)
)
ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
qh = (qh & 0x01).reshape((n_blocks, -1, 32))
q = ql | (qh << 4)
return (d * q - dm).reshape((n_blocks, QK_K))
def dequantize_blocks_Q4_K(blocks, block_size, type_size, dtype=None):
n_blocks = blocks.shape[0]
d, dmin, scales, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE)
d = d.view(torch.float16).to(dtype)
dmin = dmin.view(torch.float16).to(dtype)
sc, m = get_scale_min(scales)
d = (d * sc).reshape((n_blocks, -1, 1))
dm = (dmin * m).reshape((n_blocks, -1, 1))
qs = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
(1, 1, 2, 1)
)
qs = (qs & 0x0F).reshape((n_blocks, -1, 32))
return (d * qs - dm).reshape((n_blocks, QK_K))
def dequantize_blocks_Q3_K(blocks, block_size, type_size, dtype=None):
n_blocks = blocks.shape[0]
hmask, qs, scales, d = split_block_dims(blocks, QK_K // 8, QK_K // 4, 12)
d = d.view(torch.float16).to(dtype)
lscales, hscales = scales[:, :8], scales[:, 8:]
lscales = lscales.reshape((n_blocks, 1, 8)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(
(1, 2, 1)
)
lscales = lscales.reshape((n_blocks, 16))
hscales = hscales.reshape((n_blocks, 1, 4)) >> torch.tensor(
[0, 2, 4, 6], device=d.device, dtype=torch.uint8
).reshape((1, 4, 1))
hscales = hscales.reshape((n_blocks, 16))
scales = (lscales & 0x0F) | ((hscales & 0x03) << 4)
scales = scales.to(torch.int8) - 32
dl = (d * scales).reshape((n_blocks, 16, 1))
ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape(
(1, 1, 4, 1)
)
qh = hmask.reshape(n_blocks, -1, 1, 32) >> torch.arange(0, 8, device=d.device, dtype=torch.uint8).reshape(
(1, 1, 8, 1)
)
ql = ql.reshape((n_blocks, 16, QK_K // 16)) & 3
qh = (qh.reshape((n_blocks, 16, QK_K // 16)) & 1) ^ 1
q = ql.to(torch.int8) - (qh << 2).to(torch.int8)
return (dl * q).reshape((n_blocks, QK_K))
def dequantize_blocks_Q2_K(blocks, block_size, type_size, dtype=None):
n_blocks = blocks.shape[0]
scales, qs, d, dmin = split_block_dims(blocks, QK_K // 16, QK_K // 4, 2)
d = d.view(torch.float16).to(dtype)
dmin = dmin.view(torch.float16).to(dtype)
# (n_blocks, 16, 1)
dl = (d * (scales & 0xF)).reshape((n_blocks, QK_K // 16, 1))
ml = (dmin * (scales >> 4)).reshape((n_blocks, QK_K // 16, 1))
shift = torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 1, 4, 1))
qs = (qs.reshape((n_blocks, -1, 1, 32)) >> shift) & 3
qs = qs.reshape((n_blocks, QK_K // 16, 16))
qs = dl * qs - ml
return qs.reshape((n_blocks, -1))
def dequantize_blocks_BF16(blocks, block_size, type_size, dtype=None):
return (blocks.view(torch.int16).to(torch.int32) << 16).view(torch.float32)
GGML_QUANT_SIZES = gguf.GGML_QUANT_SIZES
dequantize_functions = {
gguf.GGMLQuantizationType.BF16: dequantize_blocks_BF16,
gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0,
gguf.GGMLQuantizationType.Q5_1: dequantize_blocks_Q5_1,
gguf.GGMLQuantizationType.Q5_0: dequantize_blocks_Q5_0,
gguf.GGMLQuantizationType.Q4_1: dequantize_blocks_Q4_1,
gguf.GGMLQuantizationType.Q4_0: dequantize_blocks_Q4_0,
gguf.GGMLQuantizationType.Q6_K: dequantize_blocks_Q6_K,
gguf.GGMLQuantizationType.Q5_K: dequantize_blocks_Q5_K,
gguf.GGMLQuantizationType.Q4_K: dequantize_blocks_Q4_K,
gguf.GGMLQuantizationType.Q3_K: dequantize_blocks_Q3_K,
gguf.GGMLQuantizationType.Q2_K: dequantize_blocks_Q2_K,
}
SUPPORTED_GGUF_QUANT_TYPES = list(dequantize_functions.keys())
def _quant_shape_from_byte_shape(shape, type_size, block_size):
return (*shape[:-1], shape[-1] // type_size * block_size)
def dequantize_gguf_tensor(tensor):
if not hasattr(tensor, "quant_type"):
return tensor
quant_type = tensor.quant_type
dequant_fn = dequantize_functions[quant_type]
block_size, type_size = GGML_QUANT_SIZES[quant_type]
tensor = tensor.view(torch.uint8)
shape = _quant_shape_from_byte_shape(tensor.shape, type_size, block_size)
n_blocks = tensor.numel() // type_size
blocks = tensor.reshape((n_blocks, type_size))
dequant = dequant_fn(blocks, block_size, type_size)
dequant = dequant.reshape(shape)
return dequant.as_tensor()
class GGUFParameter(torch.nn.Parameter):
def __new__(cls, data, requires_grad=False, quant_type=None):
data = data if data is not None else torch.empty(0)
self = torch.Tensor._make_subclass(cls, data, requires_grad)
self.quant_type = quant_type
return self
def as_tensor(self):
return torch.Tensor._make_subclass(torch.Tensor, self, self.requires_grad)
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
result = super().__torch_function__(func, types, args, kwargs)
# When converting from original format checkpoints we often use splits, cats etc on tensors
# this method ensures that the returned tensor type from those operations remains GGUFParameter
# so that we preserve quant_type information
quant_type = None
for arg in args:
if isinstance(arg, list) and (arg[0], GGUFParameter):
quant_type = arg[0].quant_type
break
if isinstance(arg, GGUFParameter):
quant_type = arg.quant_type
break
if isinstance(result, torch.Tensor):
return cls(result, quant_type=quant_type)
# Handle tuples and lists
elif isinstance(result, (tuple, list)):
# Preserve the original type (tuple or list)
wrapped = [cls(x, quant_type=quant_type) if isinstance(x, torch.Tensor) else x for x in result]
return type(result)(wrapped)
else:
return result
class GGUFLinear(nn.Linear):
def __init__(
self,
in_features,
out_features,
bias=False,
compute_dtype=None,
device=None,
) -> None:
super().__init__(in_features, out_features, bias, device)
self.compute_dtype = compute_dtype
def forward(self, inputs):
weight = dequantize_gguf_tensor(self.weight)
weight = weight.to(self.compute_dtype)
bias = self.bias.to(self.compute_dtype)
output = torch.nn.functional.linear(inputs, weight, bias)
return output

View File

@@ -43,6 +43,7 @@ logger = logging.get_logger(__name__)
class QuantizationMethod(str, Enum):
BITS_AND_BYTES = "bitsandbytes"
GGUF = "gguf"
TORCHAO = "torchao"
@@ -394,6 +395,29 @@ class BitsAndBytesConfig(QuantizationConfigMixin):
return serializable_config_dict
@dataclass
class GGUFQuantizationConfig(QuantizationConfigMixin):
"""This is a config class for GGUF Quantization techniques.
Args:
compute_dtype: (`torch.dtype`, defaults to `torch.float32`):
This sets the computational type which might be different than the input type. For example, inputs might be
fp32, but computation can be set to bf16 for speedups.
"""
def __init__(self, compute_dtype: Optional["torch.dtype"] = None):
self.quant_method = QuantizationMethod.GGUF
self.compute_dtype = compute_dtype
self.pre_quantized = True
# TODO: (Dhruv) Add this as an init argument when we can support loading unquantized checkpoints.
self.modules_to_not_convert = None
if self.compute_dtype is None:
self.compute_dtype = torch.float32
@dataclass
class TorchAoConfig(QuantizationConfigMixin):
"""This is a config class for torchao quantization/sparsity techniques.

View File

@@ -23,6 +23,7 @@ from .constants import (
DEPRECATED_REVISION_ARGS,
DIFFUSERS_DYNAMIC_MODULE_NAME,
FLAX_WEIGHTS_NAME,
GGUF_FILE_EXTENSION,
HF_MODULES_CACHE,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
MIN_PEFT_VERSION,
@@ -66,6 +67,8 @@ from .import_utils import (
is_bs4_available,
is_flax_available,
is_ftfy_available,
is_gguf_available,
is_gguf_version,
is_google_colab,
is_inflect_available,
is_invisible_watermark_available,

View File

@@ -34,6 +34,7 @@ ONNX_WEIGHTS_NAME = "model.onnx"
SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
SAFE_WEIGHTS_INDEX_NAME = "diffusion_pytorch_model.safetensors.index.json"
SAFETENSORS_FILE_EXTENSION = "safetensors"
GGUF_FILE_EXTENSION = "gguf"
ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"

View File

@@ -339,6 +339,14 @@ if _imageio_available:
except importlib_metadata.PackageNotFoundError:
_imageio_available = False
_is_gguf_available = importlib.util.find_spec("gguf") is not None
if _is_gguf_available:
try:
_gguf_version = importlib_metadata.version("gguf")
logger.debug(f"Successfully import gguf version {_gguf_version}")
except importlib_metadata.PackageNotFoundError:
_is_gguf_available = False
_is_torchao_available = importlib.util.find_spec("torchao") is not None
if _is_torchao_available:
@@ -469,6 +477,10 @@ def is_imageio_available():
return _imageio_available
def is_gguf_available():
return _is_gguf_available
def is_torchao_available():
return _is_torchao_available
@@ -607,8 +619,13 @@ IMAGEIO_IMPORT_ERROR = """
"""
# docstyle-ignore
GGUF_IMPORT_ERROR = """
{0} requires the gguf library but it was not found in your environment. You can install it with pip: `pip install gguf`
"""
TORCHAO_IMPORT_ERROR = """
{0} requires the torchao library but it was not found in your environment. You can install it with pip: `pip install torchao`
{0} requires the torchao library but it was not found in your environment. You can install it with pip: `pip install
torchao`
"""
BACKENDS_MAPPING = OrderedDict(
@@ -636,6 +653,7 @@ BACKENDS_MAPPING = OrderedDict(
("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)),
("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)),
("imageio", (is_imageio_available, IMAGEIO_IMPORT_ERROR)),
("gguf", (is_gguf_available, GGUF_IMPORT_ERROR)),
("torchao", (is_torchao_available, TORCHAO_IMPORT_ERROR)),
]
)
@@ -793,6 +811,21 @@ def is_bitsandbytes_version(operation: str, version: str):
return compare_versions(parse(_bitsandbytes_version), operation, version)
def is_gguf_version(operation: str, version: str):
"""
Compares the current Accelerate version to a given reference with an operation.
Args:
operation (`str`):
A string representation of an operator, such as `">"` or `"<="`
version (`str`):
A version string
"""
if not _is_gguf_available:
return False
return compare_versions(parse(_gguf_version), operation, version)
def is_k_diffusion_version(operation: str, version: str):
"""
Compares the current k-diffusion version to a given reference with an operation.

View File

@@ -32,6 +32,7 @@ from .import_utils import (
is_bitsandbytes_available,
is_compel_available,
is_flax_available,
is_gguf_available,
is_note_seq_available,
is_onnx_available,
is_opencv_available,
@@ -477,6 +478,18 @@ def require_bitsandbytes_version_greater(bnb_version):
return decorator
def require_gguf_version_greater_or_equal(gguf_version):
def decorator(test_case):
correct_gguf_version = is_gguf_available() and version.parse(
version.parse(importlib.metadata.version("gguf")).base_version
) >= version.parse(gguf_version)
return unittest.skipUnless(
correct_gguf_version, f"Test requires gguf with the version greater than {gguf_version}."
)(test_case)
return decorator
def require_torchao_version_greater(torchao_version):
def decorator(test_case):
correct_torchao_version = is_torchao_available() and version.parse(

View File

@@ -0,0 +1,379 @@
import gc
import unittest
import numpy as np
import torch
import torch.nn as nn
from diffusers import (
FluxPipeline,
FluxTransformer2DModel,
GGUFQuantizationConfig,
SD3Transformer2DModel,
StableDiffusion3Pipeline,
)
from diffusers.utils.testing_utils import (
is_gguf_available,
nightly,
numpy_cosine_similarity_distance,
require_accelerate,
require_big_gpu_with_torch_cuda,
require_gguf_version_greater_or_equal,
torch_device,
)
if is_gguf_available():
from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter
@nightly
@require_big_gpu_with_torch_cuda
@require_accelerate
@require_gguf_version_greater_or_equal("0.10.0")
class GGUFSingleFileTesterMixin:
ckpt_path = None
model_cls = None
torch_dtype = torch.bfloat16
expected_memory_use_in_gb = 5
def test_gguf_parameters(self):
quant_storage_type = torch.uint8
quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
model = self.model_cls.from_single_file(self.ckpt_path, quantization_config=quantization_config)
for param_name, param in model.named_parameters():
if isinstance(param, GGUFParameter):
assert hasattr(param, "quant_type")
assert param.dtype == quant_storage_type
def test_gguf_linear_layers(self):
quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
model = self.model_cls.from_single_file(self.ckpt_path, quantization_config=quantization_config)
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear) and hasattr(module.weight, "quant_type"):
assert module.weight.dtype == torch.uint8
assert module.bias.dtype == torch.float32
def test_gguf_memory_usage(self):
quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
model = self.model_cls.from_single_file(
self.ckpt_path, quantization_config=quantization_config, torch_dtype=self.torch_dtype
)
model.to("cuda")
assert (model.get_memory_footprint() / 1024**3) < self.expected_memory_use_in_gb
inputs = self.get_dummy_inputs()
torch.cuda.reset_peak_memory_stats()
torch.cuda.empty_cache()
with torch.no_grad():
model(**inputs)
max_memory = torch.cuda.max_memory_allocated()
assert (max_memory / 1024**3) < self.expected_memory_use_in_gb
def test_keep_modules_in_fp32(self):
r"""
A simple tests to check if the modules under `_keep_in_fp32_modules` are kept in fp32.
Also ensures if inference works.
"""
_keep_in_fp32_modules = self.model_cls._keep_in_fp32_modules
self.model_cls._keep_in_fp32_modules = ["proj_out"]
quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
model = self.model_cls.from_single_file(self.ckpt_path, quantization_config=quantization_config)
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
if name in model._keep_in_fp32_modules:
assert module.weight.dtype == torch.float32
self.model_cls._keep_in_fp32_modules = _keep_in_fp32_modules
def test_dtype_assignment(self):
quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
model = self.model_cls.from_single_file(self.ckpt_path, quantization_config=quantization_config)
with self.assertRaises(ValueError):
# Tries with a `dtype`
model.to(torch.float16)
with self.assertRaises(ValueError):
# Tries with a `device` and `dtype`
model.to(device="cuda:0", dtype=torch.float16)
with self.assertRaises(ValueError):
# Tries with a cast
model.float()
with self.assertRaises(ValueError):
# Tries with a cast
model.half()
# This should work
model.to("cuda")
def test_dequantize_model(self):
quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
model = self.model_cls.from_single_file(self.ckpt_path, quantization_config=quantization_config)
model.dequantize()
def _check_for_gguf_linear(model):
has_children = list(model.children())
if not has_children:
return
for name, module in model.named_children():
if isinstance(module, nn.Linear):
assert not isinstance(module, GGUFLinear), f"{name} is still GGUFLinear"
assert not isinstance(module.weight, GGUFParameter), f"{name} weight is still GGUFParameter"
for name, module in model.named_children():
_check_for_gguf_linear(module)
class FluxGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
torch_dtype = torch.bfloat16
model_cls = FluxTransformer2DModel
expected_memory_use_in_gb = 5
def setUp(self):
gc.collect()
torch.cuda.empty_cache()
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
def get_dummy_inputs(self):
return {
"hidden_states": torch.randn((1, 4096, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
torch_device, self.torch_dtype
),
"encoder_hidden_states": torch.randn(
(1, 512, 4096),
generator=torch.Generator("cpu").manual_seed(0),
).to(torch_device, self.torch_dtype),
"pooled_projections": torch.randn(
(1, 768),
generator=torch.Generator("cpu").manual_seed(0),
).to(torch_device, self.torch_dtype),
"timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
"img_ids": torch.randn((4096, 3), generator=torch.Generator("cpu").manual_seed(0)).to(
torch_device, self.torch_dtype
),
"txt_ids": torch.randn((512, 3), generator=torch.Generator("cpu").manual_seed(0)).to(
torch_device, self.torch_dtype
),
"guidance": torch.tensor([3.5]).to(torch_device, self.torch_dtype),
}
def test_pipeline_inference(self):
quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
transformer = self.model_cls.from_single_file(
self.ckpt_path, quantization_config=quantization_config, torch_dtype=self.torch_dtype
)
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=self.torch_dtype
)
pipe.enable_model_cpu_offload()
prompt = "a cat holding a sign that says hello"
output = pipe(
prompt=prompt, num_inference_steps=2, generator=torch.Generator("cpu").manual_seed(0), output_type="np"
).images[0]
output_slice = output[:3, :3, :].flatten()
expected_slice = np.array(
[
0.47265625,
0.43359375,
0.359375,
0.47070312,
0.421875,
0.34375,
0.46875,
0.421875,
0.34765625,
0.46484375,
0.421875,
0.34179688,
0.47070312,
0.42578125,
0.34570312,
0.46875,
0.42578125,
0.3515625,
0.45507812,
0.4140625,
0.33984375,
0.4609375,
0.41796875,
0.34375,
0.45898438,
0.41796875,
0.34375,
]
)
max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
assert max_diff < 1e-4
class SD35LargeGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
ckpt_path = "https://huggingface.co/city96/stable-diffusion-3.5-large-gguf/blob/main/sd3.5_large-Q4_0.gguf"
torch_dtype = torch.bfloat16
model_cls = SD3Transformer2DModel
expected_memory_use_in_gb = 5
def setUp(self):
gc.collect()
torch.cuda.empty_cache()
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
def get_dummy_inputs(self):
return {
"hidden_states": torch.randn((1, 16, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
torch_device, self.torch_dtype
),
"encoder_hidden_states": torch.randn(
(1, 512, 4096),
generator=torch.Generator("cpu").manual_seed(0),
).to(torch_device, self.torch_dtype),
"pooled_projections": torch.randn(
(1, 2048),
generator=torch.Generator("cpu").manual_seed(0),
).to(torch_device, self.torch_dtype),
"timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
}
def test_pipeline_inference(self):
quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
transformer = self.model_cls.from_single_file(
self.ckpt_path, quantization_config=quantization_config, torch_dtype=self.torch_dtype
)
pipe = StableDiffusion3Pipeline.from_pretrained(
"stabilityai/stable-diffusion-3.5-large", transformer=transformer, torch_dtype=self.torch_dtype
)
pipe.enable_model_cpu_offload()
prompt = "a cat holding a sign that says hello"
output = pipe(
prompt=prompt, num_inference_steps=2, generator=torch.Generator("cpu").manual_seed(0), output_type="np"
).images[0]
output_slice = output[:3, :3, :].flatten()
expected_slice = np.array(
[
0.17578125,
0.27539062,
0.27734375,
0.11914062,
0.26953125,
0.25390625,
0.109375,
0.25390625,
0.25,
0.15039062,
0.26171875,
0.28515625,
0.13671875,
0.27734375,
0.28515625,
0.12109375,
0.26757812,
0.265625,
0.16210938,
0.29882812,
0.28515625,
0.15625,
0.30664062,
0.27734375,
0.14648438,
0.29296875,
0.26953125,
]
)
max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
assert max_diff < 1e-4
class SD35MediumGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
ckpt_path = "https://huggingface.co/city96/stable-diffusion-3.5-medium-gguf/blob/main/sd3.5_medium-Q3_K_M.gguf"
torch_dtype = torch.bfloat16
model_cls = SD3Transformer2DModel
expected_memory_use_in_gb = 2
def setUp(self):
gc.collect()
torch.cuda.empty_cache()
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
def get_dummy_inputs(self):
return {
"hidden_states": torch.randn((1, 16, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
torch_device, self.torch_dtype
),
"encoder_hidden_states": torch.randn(
(1, 512, 4096),
generator=torch.Generator("cpu").manual_seed(0),
).to(torch_device, self.torch_dtype),
"pooled_projections": torch.randn(
(1, 2048),
generator=torch.Generator("cpu").manual_seed(0),
).to(torch_device, self.torch_dtype),
"timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
}
def test_pipeline_inference(self):
quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
transformer = self.model_cls.from_single_file(
self.ckpt_path, quantization_config=quantization_config, torch_dtype=self.torch_dtype
)
pipe = StableDiffusion3Pipeline.from_pretrained(
"stabilityai/stable-diffusion-3.5-medium", transformer=transformer, torch_dtype=self.torch_dtype
)
pipe.enable_model_cpu_offload()
prompt = "a cat holding a sign that says hello"
output = pipe(
prompt=prompt, num_inference_steps=2, generator=torch.Generator("cpu").manual_seed(0), output_type="np"
).images[0]
output_slice = output[:3, :3, :].flatten()
expected_slice = np.array(
[
0.625,
0.6171875,
0.609375,
0.65625,
0.65234375,
0.640625,
0.6484375,
0.640625,
0.625,
0.6484375,
0.63671875,
0.6484375,
0.66796875,
0.65625,
0.65234375,
0.6640625,
0.6484375,
0.6328125,
0.6640625,
0.6484375,
0.640625,
0.67578125,
0.66015625,
0.62109375,
0.671875,
0.65625,
0.62109375,
]
)
max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
assert max_diff < 1e-4