diffusers/docs/source/en/quantization/bitsandbytes.md
Sayak Paul b821f006d0 [Quantization] Add quantization support for bitsandbytes (#9213)
Co-authored-by: Vishnu V Jaddipal <95531133+Gothos@users.noreply.github.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2024-10-21 10:11:57 +05:30

bitsandbytes

bitsandbytes is the easiest option for quantizing a model to 8-bit and 4-bit. 8-bit quantization performs the matrix multiplication for outlier features in fp16 and for non-outlier features in int8, converts the int8 results back to fp16, and then adds the two together to return the output in fp16. This reduces the degradative effect outlier values have on a model's performance.
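To see why outliers are treated separately, the following pure-Python sketch quantizes a small vector to int8 with a single absmax scale, once without and once with a large outlier. The helper names and the sample values are illustrative assumptions, not part of bitsandbytes; the point is only that one large value stretches the scale and makes every other value less precise.

```python
def absmax_quantize(values):
    """Quantize floats to the int8 range with one absmax scale; return (ints, scale)."""
    scale = max(abs(v) for v in values) / 127.0
    return [round(v / scale) for v in values], scale


def dequantize(ints, scale):
    """Map int8 codes back to floats."""
    return [i * scale for i in ints]


def mean_abs_error(original, restored):
    return sum(abs(a - b) for a, b in zip(original, restored)) / len(original)


# Illustrative values: eight "regular" activations and one outlier.
regular = [0.12, -0.53, 0.98, -1.40, 0.07, 0.66, -0.91, 1.25]
with_outlier = regular + [60.0]

ints_regular, scale_regular = absmax_quantize(regular)
error_regular = mean_abs_error(regular, dequantize(ints_regular, scale_regular))

# The outlier now sets the scale; measure the error on the same eight regular values.
ints, scale = absmax_quantize(with_outlier)
restored = dequantize(ints, scale)[: len(regular)]
error_with_outlier = mean_abs_error(regular, restored)

print(f"mean abs error without outlier: {error_regular:.4f}")
print(f"mean abs error with outlier:    {error_with_outlier:.4f}")
```

The error on the regular values grows by more than an order of magnitude once the outlier shares their scale, which is the effect the fp16 path for outliers is meant to avoid.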

4-bit quantization compresses a model even further, and it is commonly used with QLoRA to finetune quantized LLMs.

To use bitsandbytes, make sure you have the following libraries installed:

pip install diffusers transformers accelerate bitsandbytes -U

Now you can quantize a model by passing a [BitsAndBytesConfig] to [~ModelMixin.from_pretrained]. This works for any model in any modality, as long as it supports loading with Accelerate and contains torch.nn.Linear layers.

Quantizing a model in 8-bit halves the memory usage compared to 16-bit weights:

from diffusers import FluxTransformer2DModel, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_8bit=True)

model_8bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", 
    subfolder="transformer",
    quantization_config=quantization_config
)

By default, all the other modules such as torch.nn.LayerNorm are converted to torch.float16. You can change the data type of these modules with the torch_dtype parameter if you want:

import torch
from diffusers import FluxTransformer2DModel, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_8bit=True)

model_8bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.float32
)
# Collect the dtypes of the parameters that were not quantized to int8
print({param.dtype for param in model_8bit.parameters() if param.dtype != torch.int8})

Once a model is quantized, you can push the model to the Hub with the [~ModelMixin.push_to_hub] method. The quantization config.json file is pushed first, followed by the quantized model weights.

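from diffusers import FluxTransformer2DModel, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_8bit=True)

model_8bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config
)
# "your-username/flux.1-dev-8bit" is a placeholder repository id; replace it with your own
model_8bit.push_to_hub("your-username/flux.1-dev-8bit")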

Quantizing a model in 4-bit reduces your memory usage by 4x compared to 16-bit weights:

from diffusers import FluxTransformer2DModel, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_4bit=True)

model_4bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", 
    subfolder="transformer",
    quantization_config=quantization_config
)

By default, all the other modules such as torch.nn.LayerNorm are converted to torch.float16. You can change the data type of these modules with the torch_dtype parameter if you want:

import torch
from diffusers import FluxTransformer2DModel, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_4bit=True)

model_4bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.float32
)
# Collect the dtypes of the parameters that were not packed into 4-bit (stored as uint8)
print({param.dtype for param in model_4bit.parameters() if param.dtype != torch.uint8})

Call [~ModelMixin.push_to_hub] after loading it in 4-bit precision. You can also save the serialized 4-bit models locally with [~ModelMixin.save_pretrained].

Training with 8-bit and 4-bit weights is only supported for training extra parameters.

Check your memory footprint with the get_memory_footprint method:

print(model.get_memory_footprint())
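As a rough cross-check for the number reported by get_memory_footprint, you can estimate the weight memory from the parameter count and the bits per parameter. The sketch below is back-of-the-envelope arithmetic only; the parameter count is a hypothetical value, and the real footprint will differ because some modules stay in higher precision and quantization constants take extra space.

```python
def weight_memory_gib(num_params: int, bits_per_param: float) -> float:
    """Estimate weight storage in GiB for a given parameter count and bit width."""
    return num_params * bits_per_param / 8 / 1024**3


# Hypothetical parameter count, chosen only for illustration.
num_params = 12_000_000_000

for label, bits in [("16-bit", 16), ("8-bit", 8), ("4-bit", 4)]:
    print(f"{label:>6}: {weight_memory_gib(num_params, bits):.2f} GiB")
```

The 8-bit estimate is half of the 16-bit one and the 4-bit estimate is a quarter, which matches the 2x and 4x reductions mentioned above.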

Quantized models can be loaded from the [~ModelMixin.from_pretrained] method without needing to specify the quantization_config parameters:

from diffusers import FluxTransformer2DModel

# The quantization config is read from the checkpoint, so none is passed here
model_4bit = FluxTransformer2DModel.from_pretrained(
    "sayakpaul/flux.1-dev-nf4-pkg", subfolder="transformer"
)

8-bit (LLM.int8() algorithm)

Learn more about the details of 8-bit quantization in this blog post!

This section explores some of the specific features of 8-bit models, such as outlier thresholds and skipping module conversion.

Outlier threshold

An "outlier" is a hidden state value greater than a certain threshold, and these values are computed in fp16. While the values are usually normally distributed ([-3.5, 3.5]), this distribution can be very different for large models ([-60, -6] or [6, 60]). 8-bit quantization works well for values ~5, but beyond that, there is a significant performance penalty. A good default threshold value is 6, but a lower threshold may be needed for more unstable models (small models or finetuning).

To find the best threshold for your model, we recommend experimenting with the llm_int8_threshold parameter in [BitsAndBytesConfig]:

from diffusers import FluxTransformer2DModel, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_8bit=True, llm_int8_threshold=10,
)

model_8bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config,
)
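The role of the threshold can be illustrated with a small pure-Python sketch of a mixed-precision matrix multiplication: feature dimensions whose hidden-state values exceed the threshold are multiplied in floating point, and the remaining dimensions go through an absmax int8 path. This is a simplified illustration under stated assumptions, not the bitsandbytes kernel; all function names and sample matrices are hypothetical.

```python
def matmul(x, w):
    """Float matrix product of x (n x k) and w (k x m), given as nested lists."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*w)] for row in x]


def quantize_rows(matrix):
    """Absmax-quantize each row to the int8 range; return (int rows, per-row scales)."""
    scales = [(max(abs(v) for v in row) / 127.0) or 1.0 for row in matrix]
    quantized = [[round(v / s) for v in row] for row, s in zip(matrix, scales)]
    return quantized, scales


def mixed_int8_matmul(x, w, threshold=6.0):
    """Compute x @ w with outlier feature dimensions in float and the rest in int8."""
    features = range(len(w))
    # A feature dimension counts as an outlier if any hidden-state value exceeds the threshold.
    outliers = [j for j in features if any(abs(row[j]) > threshold for row in x)]
    regular = [j for j in features if j not in outliers]

    result = [[0.0] * len(w[0]) for _ in x]
    if outliers:
        high = matmul([[row[j] for j in outliers] for row in x], [w[j] for j in outliers])
        result = [[r + h for r, h in zip(rr, hr)] for rr, hr in zip(result, high)]
    if regular:
        x_q, x_scales = quantize_rows([[row[j] for j in regular] for row in x])
        # Quantize w per output column by quantizing the rows of its transpose.
        w_q, w_scales = quantize_rows([list(col) for col in zip(*[w[j] for j in regular])])
        for i, x_row in enumerate(x_q):
            for c, w_col in enumerate(w_q):
                accumulator = sum(a * b for a, b in zip(x_row, w_col))  # integer arithmetic
                result[i][c] += accumulator * x_scales[i] * w_scales[c]
    return result, outliers


def max_abs_error(a, b):
    return max(abs(p - q) for ra, rb in zip(a, b) for p, q in zip(ra, rb))


# Illustrative hidden states: feature dimension 2 carries outliers.
x = [[0.5, -1.2, 45.0, 0.8], [1.1, 0.3, -38.0, -0.7]]
w = [[0.2, -0.4], [0.7, 0.1], [0.05, -0.03], [-0.6, 0.9]]

exact = matmul(x, w)
mixed, outlier_columns = mixed_int8_matmul(x, w, threshold=6.0)
int8_only, _ = mixed_int8_matmul(x, w, threshold=float("inf"))

mixed_error = max_abs_error(exact, mixed)
int8_only_error = max_abs_error(exact, int8_only)
print("outlier feature dimensions:", outlier_columns)
print(f"max error with threshold 6: {mixed_error:.4f}")
print(f"max error with int8 only:   {int8_only_error:.4f}")
```

Raising the threshold sends more values through the int8 path (saving compute in fp16 but risking precision), while lowering it does the opposite, which is why experimenting with llm_int8_threshold is recommended.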

Skip module conversion

For some models, you don't need to quantize every module to 8-bit, which can actually cause instability. For example, for diffusion models like Stable Diffusion 3, the proj_out module can be skipped using the llm_int8_skip_modules parameter in [BitsAndBytesConfig]:

from diffusers import SD3Transformer2DModel, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_8bit=True, llm_int8_skip_modules=["proj_out"],
)

model_8bit = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    quantization_config=quantization_config,
)

4-bit (QLoRA algorithm)

Learn more about its details in this blog post.

This section explores some of the specific features of 4-bit models, such as changing the compute data type, using the Normal Float 4 (NF4) data type, and using nested quantization.

Compute data type

To speed up computation, you can change the data type from float32 (the default value) to bf16 using the bnb_4bit_compute_dtype parameter in [BitsAndBytesConfig]:

import torch
from diffusers import BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)

Normal Float 4 (NF4)

NF4 is a 4-bit data type from the QLoRA paper, adapted for weights initialized from a normal distribution. You should use NF4 for training 4-bit base models. This can be configured with the bnb_4bit_quant_type parameter in the [BitsAndBytesConfig]:

from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
)

model_nf4 = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    quantization_config=nf4_config,
)

For inference, the bnb_4bit_quant_type does not have a huge impact on performance. However, to remain consistent with the model weights, you should use the same bnb_4bit_compute_dtype and torch_dtype values.
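The idea behind a data type adapted to normally distributed weights can be sketched with a quantile-based code book. The example below builds 16 levels from the quantiles of a standard normal distribution and quantizes one block of weights with an absmax scale. It is a simplified stand-in under stated assumptions and does not reproduce the exact NF4 table used by bitsandbytes; the function names and the sample block are hypothetical.

```python
from statistics import NormalDist


def normal_quantile_levels(bits=4):
    """Build 2**bits levels in [-1, 1] from evenly spaced quantiles of N(0, 1)."""
    count = 2**bits
    normal = NormalDist()
    quantiles = [normal.inv_cdf((i + 0.5) / count) for i in range(count)]
    largest = max(abs(q) for q in quantiles)
    return [q / largest for q in quantiles]


def quantize_block(block, levels):
    """Scale a block by its absmax and store the index of the nearest level."""
    absmax = max(abs(v) for v in block) or 1.0
    indices = [
        min(range(len(levels)), key=lambda i: abs(levels[i] - v / absmax))
        for v in block
    ]
    return indices, absmax


def dequantize_block(indices, absmax, levels):
    """Look up each level and undo the absmax scaling."""
    return [levels[i] * absmax for i in indices]


levels = normal_quantile_levels()

# Illustrative block of weights.
block = [0.02, -0.15, 0.31, -0.07, 0.11, -0.42, 0.05, 0.26]
indices, absmax = quantize_block(block, levels)
restored = dequantize_block(indices, absmax, levels)

print("4-bit indices:", indices)
print("restored:     ", [round(v, 3) for v in restored])
```

Because the levels are denser near zero, where most normally distributed weights fall, small weights are represented more precisely than with evenly spaced 4-bit levels.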

Nested quantization

Nested quantization is a technique that can save additional memory at no additional performance cost. This feature performs a second quantization of the already quantized weights to save an additional 0.4 bits/parameter.

from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

double_quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
)

double_quant_model = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    quantization_config=double_quant_config,
)
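The figure of roughly 0.4 bits per parameter can be reproduced with the arithmetic from the QLoRA paper. The sketch below assumes the block sizes described there (64 weights per block for the first quantization, 256 constants per block for the second, with 8-bit quantized constants); these values are assumptions for illustration and are not read from the configuration above.

```python
def constant_overhead_bits(block_size=64, constant_bits=32):
    """Bits per parameter spent on one 32-bit quantization constant per block."""
    return constant_bits / block_size


def nested_overhead_bits(
    block_size=64,
    nested_block_size=256,
    quantized_constant_bits=8,
    second_level_bits=32,
):
    """Bits per parameter when the constants themselves are quantized a second time."""
    first_level = quantized_constant_bits / block_size
    second_level = second_level_bits / (block_size * nested_block_size)
    return first_level + second_level


plain = constant_overhead_bits()
nested = nested_overhead_bits()

print(f"overhead without nested quantization: {plain:.3f} bits/parameter")
print(f"overhead with nested quantization:    {nested:.3f} bits/parameter")
print(f"savings:                              {plain - nested:.3f} bits/parameter")
```

Under these assumptions the overhead drops from 0.5 to about 0.127 bits per parameter, a savings of about 0.373 bits, which rounds to the 0.4 bits quoted above.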

Dequantizing bitsandbytes models

Once quantized, you can dequantize the model to the original precision, but this might result in a small loss in model quality. Make sure you have enough GPU RAM to fit the dequantized model.

from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

double_quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
)

double_quant_model = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    quantization_config=double_quant_config,
)
# Convert the quantized weights back to the original precision
double_quant_model.dequantize()