bitsandbytes
bitsandbytes is the easiest option for quantizing a model to 8-bit and 4-bit. 8-bit quantization splits each matrix multiplication in two: outlier values are multiplied in fp16 while the remaining values are multiplied in int8. The int8 result is converted back to fp16, and the two parts are added together to return the result in fp16. Handling outliers separately reduces the degradative effect they have on a model's performance.
4-bit quantization compresses a model even further, and it is commonly used with QLoRA to finetune quantized LLMs.
To use bitsandbytes, make sure you have the following libraries installed:
pip install diffusers transformers accelerate bitsandbytes -U
Now you can quantize a model by passing a [BitsAndBytesConfig] to [~ModelMixin.from_pretrained]. This works for any model in any modality, as long as it supports loading with Accelerate and contains torch.nn.Linear layers.
Quantizing a model in 8-bit halves the memory usage:
from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
model_8bit = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/FLUX.1-dev",
subfolder="transformer",
quantization_config=quantization_config
)
By default, all the other modules such as torch.nn.LayerNorm are converted to torch.float16. You can change the data type of these modules with the torch_dtype parameter if you want:
import torch
from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
model_8bit = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/FLUX.1-dev",
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=torch.float32
)
# Inspect which dtypes the loaded parameters use
print({param.dtype for param in model_8bit.parameters()})
Once a model is quantized, you can push the model to the Hub with the [~ModelMixin.push_to_hub] method. The quantization config.json file is pushed first, followed by the quantized model weights.
from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
model_8bit = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/FLUX.1-dev",
subfolder="transformer",
quantization_config=quantization_config
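)
# "your-username/flux.1-dev-8bit" is a placeholder repository id
model_8bit.push_to_hub("your-username/flux.1-dev-8bit")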
Quantizing a model in 4-bit reduces your memory usage to roughly a quarter:
from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
model_4bit = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/FLUX.1-dev",
subfolder="transformer",
quantization_config=quantization_config
)
By default, all the other modules such as torch.nn.LayerNorm are converted to torch.float16. You can change the data type of these modules with the torch_dtype parameter if you want:
import torch
from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
model_4bit = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/FLUX.1-dev",
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=torch.float32
)
# Inspect which dtypes the loaded parameters use
print({param.dtype for param in model_4bit.parameters()})
After loading a model in 4-bit precision, call [~ModelMixin.push_to_hub] to push it to the Hub. You can also save the serialized 4-bit model locally with [~ModelMixin.save_pretrained].
Training with 8-bit and 4-bit weights is only supported for training extra parameters.
Check your memory footprint with the get_memory_footprint method:
print(model_4bit.get_memory_footprint())
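As a rough cross-check on the reported footprint, the storage needed for the weights alone can be estimated from the parameter count and the bits per parameter. The sketch below is back-of-the-envelope arithmetic only: `estimate_weight_bytes` is a hypothetical helper, and the parameter count is an illustrative assumption rather than a figure for any specific model.

```python
def estimate_weight_bytes(num_params: int, bits_per_param: float) -> float:
    """Approximate storage for the weights alone, in bytes."""
    return num_params * bits_per_param / 8


# Hypothetical parameter count, chosen only for illustration.
num_params = 12_000_000_000

fp16_bytes = estimate_weight_bytes(num_params, 16)
int8_bytes = estimate_weight_bytes(num_params, 8)
int4_bytes = estimate_weight_bytes(num_params, 4)

for name, size in [("fp16", fp16_bytes), ("8-bit", int8_bytes), ("4-bit", int4_bytes)]:
    print(f"{name}: {size / 1024**3:.1f} GiB")
```

Expect the measured footprint to be somewhat larger than this estimate, because modules that are not quantized stay in a higher precision and the quantized weights carry scaling constants.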
Quantized models can be loaded with the [~ModelMixin.from_pretrained] method without specifying the quantization_config parameter:
from diffusers import FluxTransformer2DModel
model_4bit = FluxTransformer2DModel.from_pretrained(
"sayakpaul/flux.1-dev-nf4-pkg", subfolder="transformer"
)
8-bit (LLM.int8() algorithm)
Learn more about the details of 8-bit quantization in this blog post!
This section explores some of the specific features of 8-bit models, such as outlier thresholds and skipping module conversion.
Outlier threshold
An "outlier" is a hidden state value greater than a certain threshold, and these values are computed in fp16. While the values are usually normally distributed ([-3.5, 3.5]), this distribution can be very different for large models ([-60, -6] or [6, 60]). 8-bit quantization works well for values with a magnitude of around 5, but beyond that, there is a significant performance penalty. A good default threshold value is 6, but a lower threshold may be needed for more unstable models (small models or finetuning).
To find the best threshold for your model, we recommend experimenting with the llm_int8_threshold parameter in [BitsAndBytesConfig]:
from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(
load_in_8bit=True, llm_int8_threshold=10,
)
model_8bit = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/FLUX.1-dev",
subfolder="transformer",
quantization_config=quantization_config,
)
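To make the role of the threshold more concrete, here is a minimal NumPy sketch of the mixed-precision idea: feature dimensions containing an outlier are multiplied in floating point, and everything else goes through an int8 path. It is an illustration under simplifying assumptions, not the bitsandbytes implementation. The helper names `absmax_quantize` and `mixed_int8_matmul` are hypothetical, and float32 stands in for fp16.

```python
import numpy as np


def absmax_quantize(x: np.ndarray, axis: int):
    """Symmetric int8 quantization with one scale per row or column."""
    scale = np.max(np.abs(x), axis=axis, keepdims=True, initial=0.0) / 127.0
    scale = np.where(scale == 0, 1.0, scale)  # avoid division by zero
    q = np.round(x / scale).astype(np.int8)
    return q, scale


def mixed_int8_matmul(x: np.ndarray, w: np.ndarray, threshold: float = 6.0) -> np.ndarray:
    # Feature dimensions (columns of x) that hold at least one outlier.
    outlier_cols = np.any(np.abs(x) > threshold, axis=0)

    # Outlier part: kept in floating point.
    out_fp = x[:, outlier_cols] @ w[outlier_cols, :]

    # Regular part: quantize x per row and w per column, multiply in integers.
    xq, x_scale = absmax_quantize(x[:, ~outlier_cols], axis=1)
    wq, w_scale = absmax_quantize(w[~outlier_cols, :], axis=0)
    out_int = xq.astype(np.int32) @ wq.astype(np.int32)

    # Dequantize the integer result and add both parts together.
    return out_int * x_scale * w_scale + out_fp


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 16)).astype(np.float32)
x[:, 3] = [30.0, -25.0, 40.0, -35.0]  # inject one outlier feature dimension
w = rng.normal(size=(16, 8)).astype(np.float32)

reference = x @ w
mixed = mixed_int8_matmul(x, w, threshold=6.0)
naive = mixed_int8_matmul(x, w, threshold=np.inf)  # quantize everything

mixed_error = np.abs(mixed - reference).max()
naive_error = np.abs(naive - reference).max()
print("error with outlier handling:", mixed_error)
print("error without outlier handling:", naive_error)
```

With the outlier column quantized alongside everything else, the per-row scale is dominated by the outlier and the remaining values lose resolution, which is why the second error is larger.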
Skip module conversion
For some models, you don't need to quantize every module to 8-bit, and doing so can actually cause instability. For example, for diffusion models like Stable Diffusion 3, the proj_out module can be skipped using the llm_int8_skip_modules parameter in [BitsAndBytesConfig]:
from diffusers import SD3Transformer2DModel, BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(
load_in_8bit=True, llm_int8_skip_modules=["proj_out"],
)
model_8bit = SD3Transformer2DModel.from_pretrained(
"stabilityai/stable-diffusion-3-medium-diffusers",
subfolder="transformer",
quantization_config=quantization_config,
)
4-bit (QLoRA algorithm)
Learn more about the details of 4-bit quantization in this blog post.
This section explores some of the specific features of 4-bit models, such as changing the compute data type, using the Normal Float 4 (NF4) data type, and using nested quantization.
Compute data type
To speed up computation, you can change the compute data type from float32 (the default value) to bf16 using the bnb_4bit_compute_dtype parameter in [BitsAndBytesConfig]:
import torch
from diffusers import BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
Normal Float 4 (NF4)
NF4 is a 4-bit data type from the QLoRA paper, adapted for weights initialized from a normal distribution. You should use NF4 for training 4-bit base models. This can be configured with the bnb_4bit_quant_type parameter in the [BitsAndBytesConfig]:
from diffusers import SD3Transformer2DModel, BitsAndBytesConfig
nf4_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
)
model_nf4 = SD3Transformer2DModel.from_pretrained(
"stabilityai/stable-diffusion-3-medium-diffusers",
subfolder="transformer",
quantization_config=nf4_config,
)
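The general mechanism behind a 4-bit data type can be sketched as blockwise quantization against a 16-entry codebook: each block of weights is scaled by its absolute maximum and every value is replaced by the index of the nearest codebook level. The sketch below is illustrative only. It uses an evenly spaced stand-in codebook rather than the actual NF4 levels, the block size of 64 is an assumption, and `quantize_4bit` and `dequantize_4bit` are hypothetical helpers, not bitsandbytes functions.

```python
import numpy as np

# Stand-in codebook: 16 evenly spaced levels in [-1, 1].
# NF4 instead spaces its levels according to a normal distribution.
CODEBOOK = np.linspace(-1.0, 1.0, 16, dtype=np.float32)


def quantize_4bit(weights: np.ndarray, block_size: int = 64):
    """Return 4-bit codebook indices and one absmax scale per block."""
    blocks = weights.reshape(-1, block_size)
    absmax = np.abs(blocks).max(axis=1, keepdims=True)
    absmax = np.where(absmax == 0, 1.0, absmax)  # avoid division by zero
    normalized = blocks / absmax  # every block now lies in [-1, 1]
    # Index of the nearest codebook level for every value.
    codes = np.abs(normalized[..., None] - CODEBOOK).argmin(axis=-1).astype(np.uint8)
    return codes, absmax


def dequantize_4bit(codes: np.ndarray, absmax: np.ndarray, shape) -> np.ndarray:
    """Look up each code and rescale by the block's absmax."""
    return (CODEBOOK[codes] * absmax).reshape(shape)


rng = np.random.default_rng(0)
w = rng.normal(size=(128, 64)).astype(np.float32)

codes, absmax = quantize_4bit(w)
w_restored = dequantize_4bit(codes, absmax, w.shape)

print("largest code:", codes.max())
print("max round-trip error:", np.abs(w - w_restored).max())
```

The round-trip error is small but not zero, which is also why dequantizing a quantized model does not recover the original weights exactly.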
For inference, the bnb_4bit_quant_type does not have a huge impact on performance. However, to remain consistent with the model weights, you should use the same value for bnb_4bit_compute_dtype and torch_dtype.
Nested quantization
Nested quantization is a technique that can save additional memory at no additional performance cost. This feature performs a second quantization of the already quantized weights to save an additional 0.4 bits/parameter. Enable it with the bnb_4bit_use_double_quant parameter in [BitsAndBytesConfig]:
from diffusers import SD3Transformer2DModel, BitsAndBytesConfig
double_quant_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
)
double_quant_model = SD3Transformer2DModel.from_pretrained(
"stabilityai/stable-diffusion-3-medium-diffusers",
subfolder="transformer",
quantization_config=double_quant_config,
)
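The figure of roughly 0.4 bits per parameter can be reproduced with a little arithmetic. The sketch below assumes the setup described in the QLoRA paper: 32-bit quantization constants with a block size of 64, which nested quantization replaces with 8-bit constants plus a second level of 32-bit constants over blocks of 256. These numbers are taken from the paper, not read from the bitsandbytes source.

```python
block_size = 64          # weights per first-level block
second_block_size = 256  # first-level constants per second-level block

# Without nesting: one 32-bit constant per block of weights.
plain_overhead = 32 / block_size

# With nesting: one 8-bit constant per block, plus one 32-bit constant
# per group of 256 first-level constants.
nested_overhead = 8 / block_size + 32 / (block_size * second_block_size)

saving = plain_overhead - nested_overhead
print(f"{plain_overhead:.3f} -> {nested_overhead:.3f} bits/param (saves {saving:.3f})")
# prints: 0.500 -> 0.127 bits/param (saves 0.373)
```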
Dequantizing bitsandbytes models
Once quantized, you can dequantize a model back to its original precision, but this might result in a small loss in quality. Make sure you have enough GPU memory to fit the dequantized model.
from diffusers import SD3Transformer2DModel, BitsAndBytesConfig
double_quant_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
)
double_quant_model = SD3Transformer2DModel.from_pretrained(
"stabilityai/stable-diffusion-3-medium-diffusers",
subfolder="transformer",
quantization_config=double_quant_config,
)
double_quant_model.dequantize()