mirror of https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00

[Guide] Quantize your Diffusion Models with bnb (#10012)

* chore: initial draft
* Apply suggestions from code review
* chore: link in place
* chore: review suggestions
* adding same changes to 4 bit section
* review suggestions

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

committed by GitHub
parent 3335e2262d
commit bf64b32652
@@ -17,6 +17,12 @@ specific language governing permissions and limitations under the License.

4-bit quantization compresses a model even further, and it is commonly used with [QLoRA](https://hf.co/papers/2305.14314) to finetune quantized LLMs.

This guide demonstrates how quantization can enable running [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) on less than 16GB of VRAM and even on a free Google Colab instance.

To use bitsandbytes, make sure you have the following libraries installed:
@@ -31,70 +37,167 @@ Now you can quantize a model by passing a [`BitsAndBytesConfig`] to [`~ModelMixi

Quantizing a model in 8-bit halves the memory-usage:

bitsandbytes is supported in both Transformers and Diffusers, so you can quantize both the [`FluxTransformer2DModel`] and [`~transformers.T5EncoderModel`].

For Ada and higher-series GPUs, we recommend changing `torch_dtype` to `torch.bfloat16`.

> [!TIP]
> The [`CLIPTextModel`] and [`AutoencoderKL`] aren't quantized because they're already small in size and because [`AutoencoderKL`] only has a few `torch.nn.Linear` layers.

```py
import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from diffusers import FluxTransformer2DModel
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
from transformers import T5EncoderModel

quant_config = TransformersBitsAndBytesConfig(load_in_8bit=True)

text_encoder_2_8bit = T5EncoderModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="text_encoder_2",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)

transformer_8bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)
```
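To see the effect on memory (the 8-bit weights are stored as `int8`, i.e. one byte per value), you can sum the parameter storage yourself. A minimal sketch using only the `transformer_8bit` object defined above:

```py
# Approximate memory footprint of the quantized transformer (parameters only).
num_bytes = sum(p.numel() * p.element_size() for p in transformer_8bit.parameters())
print(f"transformer_8bit parameters: {num_bytes / 1024**3:.2f} GB")
```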
By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. You can change the data type of these modules with the `torch_dtype` parameter.

```diff
transformer_8bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quant_config,
+   torch_dtype=torch.float32,
)
```
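To check the effect, you can look at which dtypes the model's parameters ended up in. A minimal sketch, assuming the `transformer_8bit` object created above:

```py
import torch

# The quantized Linear weights are stored as torch.int8, while the remaining
# modules keep the dtype passed via torch_dtype (torch.float32 in the diff above).
dtypes = {param.dtype for param in transformer_8bit.parameters()}
print(dtypes)
```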
Let's generate an image using our quantized models.

Setting `device_map="auto"` automatically fills all available space on the GPU(s) first, then the CPU, and finally, the hard drive (the absolute slowest option) if there is still not enough memory.

```py
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer_8bit,
    text_encoder_2=text_encoder_2_8bit,
    torch_dtype=torch.float16,
    device_map="auto",
)

pipe_kwargs = {
    "prompt": "A cat holding a sign that says hello world",
    "height": 1024,
    "width": 1024,
    "guidance_scale": 3.5,
    "num_inference_steps": 50,
    "max_sequence_length": 512,
}

image = pipe(**pipe_kwargs, generator=torch.manual_seed(0)).images[0]
```

<div class="flex justify-center">
  <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/quant-bnb/8bit.png"/>
</div>
When there is enough memory, you can also directly move the pipeline to the GPU with `.to("cuda")` and apply [`~DiffusionPipeline.enable_model_cpu_offload`] to optimize GPU memory usage.
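As a minimal sketch (assuming the pipeline was loaded without `device_map="auto"` so that it can be placed manually; treat the two calls as alternatives rather than a combination):

```py
# Move every pipeline component onto the GPU when there is enough VRAM.
pipe.to("cuda")

# Or trade some speed for memory by offloading idle components to the CPU
# between forward passes.
# pipe.enable_model_cpu_offload()
```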
Once a model is quantized, you can push the model to the Hub with the [`~ModelMixin.push_to_hub`] method. The quantization `config.json` file is pushed first, followed by the quantized model weights. You can also save the serialized 8-bit models locally with [`~ModelMixin.save_pretrained`].
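For example, a small sketch of both options (the repository id and the local path below are placeholders):

```py
# Push the quantized transformer to the Hub; the quantization config is
# uploaded in config.json before the weights.
transformer_8bit.push_to_hub("your-username/flux.1-dev-transformer-8bit")

# Or save it locally together with its quantization config.
transformer_8bit.save_pretrained("./flux.1-dev-transformer-8bit")
```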
</hfoption>
<hfoption id="4-bit">

Quantizing a model in 4-bit reduces your memory-usage by 4x:

bitsandbytes is supported in both Transformers and Diffusers, so you can quantize both the [`FluxTransformer2DModel`] and [`~transformers.T5EncoderModel`].

For Ada and higher-series GPUs, we recommend changing `torch_dtype` to `torch.bfloat16`.

> [!TIP]
> The [`CLIPTextModel`] and [`AutoencoderKL`] aren't quantized because they're already small in size and because [`AutoencoderKL`] only has a few `torch.nn.Linear` layers.

```py
import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from diffusers import FluxTransformer2DModel
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
from transformers import T5EncoderModel

quant_config = TransformersBitsAndBytesConfig(load_in_4bit=True)

text_encoder_2_4bit = T5EncoderModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="text_encoder_2",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

quant_config = DiffusersBitsAndBytesConfig(load_in_4bit=True)

transformer_4bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)
```
By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. You can change the data type of these modules with the `torch_dtype` parameter.

```diff
transformer_4bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quant_config,
+   torch_dtype=torch.float32,
)
```

Let's generate an image using our quantized models.

Setting `device_map="auto"` automatically fills all available space on the GPU(s) first, then the CPU, and finally, the hard drive (the absolute slowest option) if there is still not enough memory.

```py
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer_4bit,
    text_encoder_2=text_encoder_2_4bit,
    torch_dtype=torch.float16,
    device_map="auto",
)

pipe_kwargs = {
    "prompt": "A cat holding a sign that says hello world",
    "height": 1024,
    "width": 1024,
    "guidance_scale": 3.5,
    "num_inference_steps": 50,
    "max_sequence_length": 512,
}

image = pipe(**pipe_kwargs, generator=torch.manual_seed(0)).images[0]
```

<div class="flex justify-center">
  <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/quant-bnb/4bit.png"/>
</div>
When there is enough memory, you can also directly move the pipeline to the GPU with `.to("cuda")` and apply [`~DiffusionPipeline.enable_model_cpu_offload`] to optimize GPU memory usage.
Once a model is quantized, you can push the model to the Hub with the [`~ModelMixin.push_to_hub`] method. The quantization `config.json` file is pushed first, followed by the quantized model weights. You can also save the serialized 4-bit models locally with [`~ModelMixin.save_pretrained`].
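As a sketch of round-tripping the 4-bit transformer through disk (the local path is a placeholder, and it is assumed that reloading picks the quantization settings up from the saved `config.json`):

```py
# Save the serialized 4-bit transformer locally.
transformer_4bit.save_pretrained("./flux.1-dev-transformer-4bit")

# Reload it later; the stored quantization config is assumed to be applied automatically.
transformer_4bit = FluxTransformer2DModel.from_pretrained(
    "./flux.1-dev-transformer-4bit",
    torch_dtype=torch.float16,
)
```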
</hfoption>
</hfoptions>
@@ -199,17 +302,34 @@ quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dty

NF4 is a 4-bit data type from the [QLoRA](https://hf.co/papers/2305.14314) paper, adapted for weights initialized from a normal distribution. You should use NF4 for training 4-bit base models. This can be configured with the `bnb_4bit_quant_type` parameter in the [`BitsAndBytesConfig`]:

```py
import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from diffusers import FluxTransformer2DModel
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
from transformers import T5EncoderModel

quant_config = TransformersBitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
)

text_encoder_2_4bit = T5EncoderModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="text_encoder_2",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

quant_config = DiffusersBitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
)

transformer_4bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)
```
@@ -220,38 +340,74 @@ For inference, the `bnb_4bit_quant_type` does not have a huge impact on performa

Nested quantization is a technique that can save additional memory at no additional performance cost. This feature performs a second quantization of the already quantized weights to save an additional 0.4 bits/parameter.

```py
import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from diffusers import FluxTransformer2DModel
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
from transformers import T5EncoderModel

quant_config = TransformersBitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
)

text_encoder_2_4bit = T5EncoderModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="text_encoder_2",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

quant_config = DiffusersBitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
)

transformer_4bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)
```
## Dequantizing `bitsandbytes` models

Once quantized, you can dequantize a model to its original precision, but this might result in a small loss of quality. Make sure you have enough GPU RAM to fit the dequantized model.

```python
import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from diffusers import FluxTransformer2DModel
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
from transformers import T5EncoderModel

quant_config = TransformersBitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
)

text_encoder_2_4bit = T5EncoderModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="text_encoder_2",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

quant_config = DiffusersBitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
)

transformer_4bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

text_encoder_2_4bit.dequantize()
transformer_4bit.dequantize()
```
## Resources