diff --git a/docs/source/en/quantization/bitsandbytes.md b/docs/source/en/quantization/bitsandbytes.md
index 118511b75d..266daa0193 100644
--- a/docs/source/en/quantization/bitsandbytes.md
+++ b/docs/source/en/quantization/bitsandbytes.md
@@ -17,6 +17,12 @@ specific language governing permissions and limitations under the License.
 
 4-bit quantization compresses a model even further, and it is commonly used with [QLoRA](https://hf.co/papers/2305.14314) to finetune quantized LLMs.
 
+This guide demonstrates how quantization can enable running [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) on less than 16GB of VRAM and even on a free Google Colab instance.
+
+![comparison image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/quant-bnb/comparison.png)
+
 To use bitsandbytes, make sure you have the following libraries installed:
 
@@ -31,70 +37,167 @@ Now you can quantize a model by passing a [`BitsAndBytesConfig`] to [`~ModelMixi
 
 Quantizing a model in 8-bit halves the memory-usage:
 
+bitsandbytes is supported in both Transformers and Diffusers, so you can quantize both the
+[`FluxTransformer2DModel`] and [`~transformers.T5EncoderModel`].
+
+For Ada and higher-series GPUs, we recommend changing `torch_dtype` to `torch.bfloat16`.
+
+> [!TIP]
+> The [`CLIPTextModel`] and [`AutoencoderKL`] aren't quantized because they're already small in size and because [`AutoencoderKL`] only has a few `torch.nn.Linear` layers.
+
 ```py
-from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
+import torch
+from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
+from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
 
-quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+from diffusers import FluxTransformer2DModel
+from transformers import T5EncoderModel
 
-model_8bit = FluxTransformer2DModel.from_pretrained(
-    "black-forest-labs/FLUX.1-dev",
+quant_config = TransformersBitsAndBytesConfig(load_in_8bit=True)
+
+text_encoder_2_8bit = T5EncoderModel.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
+    subfolder="text_encoder_2",
+    quantization_config=quant_config,
+    torch_dtype=torch.float16,
+)
+
+quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
+
+transformer_8bit = FluxTransformer2DModel.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
     subfolder="transformer",
-    quantization_config=quantization_config
+    quantization_config=quant_config,
+    torch_dtype=torch.float16,
 )
 ```
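+You can sanity-check the savings before building a pipeline. A minimal sketch, assuming the loaded models expose `get_memory_footprint` (available in recent Transformers and Diffusers releases):
+
+```py
+# Approximate size of the quantized weights, in bytes.
+print(transformer_8bit.get_memory_footprint())
+print(text_encoder_2_8bit.get_memory_footprint())
+```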
-By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. You can change the data type of these modules with the `torch_dtype` parameter if you want:
+By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. You can change the data type of these modules with the `torch_dtype` parameter.
 
-```py
-from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
-
-quantization_config = BitsAndBytesConfig(load_in_8bit=True)
-
-model_8bit = FluxTransformer2DModel.from_pretrained(
-    "black-forest-labs/FLUX.1-dev",
+```diff
+transformer_8bit = FluxTransformer2DModel.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
     subfolder="transformer",
-    quantization_config=quantization_config,
-    torch_dtype=torch.float32
+    quantization_config=quant_config,
++    torch_dtype=torch.float32,
 )
-model_8bit.transformer_blocks.layers[-1].norm2.weight.dtype
 ```
 
-Once a model is quantized, you can push the model to the Hub with the [`~ModelMixin.push_to_hub`] method. The quantization `config.json` file is pushed first, followed by the quantized model weights. You can also save the serialized 4-bit models locally with [`~ModelMixin.save_pretrained`].
+Let's generate an image using our quantized models.
+
+Setting `device_map="auto"` automatically fills all available space on the GPU(s) first, then the CPU, and finally the hard drive (the absolute slowest option) if there is still not enough memory.
+
+```py
+from diffusers import FluxPipeline
+
+pipe = FluxPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
+    transformer=transformer_8bit,
+    text_encoder_2=text_encoder_2_8bit,
+    torch_dtype=torch.float16,
+    device_map="auto",
+)
+
+pipe_kwargs = {
+    "prompt": "A cat holding a sign that says hello world",
+    "height": 1024,
+    "width": 1024,
+    "guidance_scale": 3.5,
+    "num_inference_steps": 50,
+    "max_sequence_length": 512,
+}
+
+image = pipe(**pipe_kwargs, generator=torch.manual_seed(0)).images[0]
+```
+ +
+
+Once a model is quantized, you can push the model to the Hub with the [`~ModelMixin.push_to_hub`] method. The quantization `config.json` file is pushed first, followed by the quantized model weights. You can also save the serialized 8-bit models locally with [`~ModelMixin.save_pretrained`].
 
 Quantizing a model in 4-bit reduces your memory-usage by 4x:
 
+bitsandbytes is supported in both Transformers and Diffusers, so you can quantize both the
+[`FluxTransformer2DModel`] and [`~transformers.T5EncoderModel`].
+
+For Ada and higher-series GPUs, we recommend changing `torch_dtype` to `torch.bfloat16`.
+
+> [!TIP]
+> The [`CLIPTextModel`] and [`AutoencoderKL`] aren't quantized because they're already small in size and because [`AutoencoderKL`] only has a few `torch.nn.Linear` layers.
+
 ```py
-from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
+import torch
+from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
+from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
 
-quantization_config = BitsAndBytesConfig(load_in_4bit=True)
+from diffusers import FluxTransformer2DModel
+from transformers import T5EncoderModel
 
-model_4bit = FluxTransformer2DModel.from_pretrained(
-    "black-forest-labs/FLUX.1-dev",
+quant_config = TransformersBitsAndBytesConfig(load_in_4bit=True)
+
+text_encoder_2_4bit = T5EncoderModel.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
+    subfolder="text_encoder_2",
+    quantization_config=quant_config,
+    torch_dtype=torch.float16,
+)
+
+quant_config = DiffusersBitsAndBytesConfig(load_in_4bit=True)
+
+transformer_4bit = FluxTransformer2DModel.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
     subfolder="transformer",
-    quantization_config=quantization_config
+    quantization_config=quant_config,
+    torch_dtype=torch.float16,
 )
 ```
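+Under the hood, bitsandbytes swaps out the model's `torch.nn.Linear` layers. A quick check, assuming the `bitsandbytes` package layout with `bnb.nn.Linear4bit`:
+
+```py
+import bitsandbytes as bnb
+
+# Count the Linear layers that were replaced with 4-bit bitsandbytes layers.
+num_4bit = sum(isinstance(m, bnb.nn.Linear4bit) for m in transformer_4bit.modules())
+print(f"{num_4bit} Linear4bit layers")
+```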
-By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. You can change the data type of these modules with the `torch_dtype` parameter if you want:
+By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. You can change the data type of these modules with the `torch_dtype` parameter.
 
-```py
-from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
-
-quantization_config = BitsAndBytesConfig(load_in_4bit=True)
-
-model_4bit = FluxTransformer2DModel.from_pretrained(
-    "black-forest-labs/FLUX.1-dev",
+```diff
+transformer_4bit = FluxTransformer2DModel.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
     subfolder="transformer",
-    quantization_config=quantization_config,
-    torch_dtype=torch.float32
+    quantization_config=quant_config,
++    torch_dtype=torch.float32,
 )
-model_4bit.transformer_blocks.layers[-1].norm2.weight.dtype
 ```
 
-Call [`~ModelMixin.push_to_hub`] after loading it in 4-bit precision. You can also save the serialized 4-bit models locally with [`~ModelMixin.save_pretrained`].
+Let's generate an image using our quantized models.
+
+Setting `device_map="auto"` automatically fills all available space on the GPU(s) first, then the CPU, and finally the hard drive (the absolute slowest option) if there is still not enough memory.
+
+```py
+from diffusers import FluxPipeline
+
+pipe = FluxPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
+    transformer=transformer_4bit,
+    text_encoder_2=text_encoder_2_4bit,
+    torch_dtype=torch.float16,
+    device_map="auto",
+)
+
+pipe_kwargs = {
+    "prompt": "A cat holding a sign that says hello world",
+    "height": 1024,
+    "width": 1024,
+    "guidance_scale": 3.5,
+    "num_inference_steps": 50,
+    "max_sequence_length": 512,
+}
+
+image = pipe(**pipe_kwargs, generator=torch.manual_seed(0)).images[0]
+```
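+To verify that generation stays under the guide's 16GB VRAM budget, you can read PyTorch's allocator statistics. A small sketch, assuming a single CUDA device:
+
+```py
+# Peak GPU memory allocated during the run, in GB.
+print(f"{torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
+```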
+ +
+
+When there is enough memory, you can also directly move the pipeline to the GPU with `.to("cuda")`, or apply [`~DiffusionPipeline.enable_model_cpu_offload`] to optimize GPU memory usage.
+
+Once a model is quantized, you can push the model to the Hub with the [`~ModelMixin.push_to_hub`] method. The quantization `config.json` file is pushed first, followed by the quantized model weights. You can also save the serialized 4-bit models locally with [`~ModelMixin.save_pretrained`].
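+A minimal sketch of serializing and reloading the 4-bit transformer (the repo id is a placeholder for illustration):
+
+```py
+# Push to the Hub (hypothetical repo id), or save locally.
+transformer_4bit.push_to_hub("your-username/FLUX.1-dev-bnb-4bit-transformer")
+transformer_4bit.save_pretrained("./flux-transformer-4bit")
+
+# The quantization config is serialized with the weights, so reloading
+# typically does not require passing quantization_config again.
+transformer_4bit = FluxTransformer2DModel.from_pretrained(
+    "./flux-transformer-4bit", torch_dtype=torch.float16
+)
+```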
@@ -199,17 +302,34 @@ quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dty
 
 NF4 is a 4-bit data type from the [QLoRA](https://hf.co/papers/2305.14314) paper, adapted for weights initialized from a normal distribution. You should use NF4 for training 4-bit base models. This can be configured with the `bnb_4bit_quant_type` parameter in the [`BitsAndBytesConfig`]:
 
 ```py
-from diffusers import BitsAndBytesConfig
+import torch
+from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
+from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
 
-nf4_config = BitsAndBytesConfig(
+from diffusers import FluxTransformer2DModel
+from transformers import T5EncoderModel
+
+quant_config = TransformersBitsAndBytesConfig(
     load_in_4bit=True,
     bnb_4bit_quant_type="nf4",
 )
 
-model_nf4 = SD3Transformer2DModel.from_pretrained(
-    "stabilityai/stable-diffusion-3-medium-diffusers",
+text_encoder_2_4bit = T5EncoderModel.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
+    subfolder="text_encoder_2",
+    quantization_config=quant_config,
+    torch_dtype=torch.float16,
+)
+
+quant_config = DiffusersBitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_quant_type="nf4",
+)
+
+transformer_4bit = FluxTransformer2DModel.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
     subfolder="transformer",
-    quantization_config=nf4_config,
+    quantization_config=quant_config,
+    torch_dtype=torch.float16,
 )
 ```
 
@@ -220,38 +340,74 @@ For inference, the `bnb_4bit_quant_type` does not have a huge impact on performa
 
 Nested quantization is a technique that can save additional memory at no additional performance cost. This feature performs a second quantization of the already quantized weights to save an additional 0.4 bits/parameter.
 
 ```py
-from diffusers import BitsAndBytesConfig
+import torch
+from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
+from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
 
-double_quant_config = BitsAndBytesConfig(
+from diffusers import FluxTransformer2DModel
+from transformers import T5EncoderModel
+
+quant_config = TransformersBitsAndBytesConfig(
     load_in_4bit=True,
     bnb_4bit_use_double_quant=True,
 )
 
-double_quant_model = SD3Transformer2DModel.from_pretrained(
-    "stabilityai/stable-diffusion-3-medium-diffusers",
+text_encoder_2_4bit = T5EncoderModel.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
+    subfolder="text_encoder_2",
+    quantization_config=quant_config,
+    torch_dtype=torch.float16,
+)
+
+quant_config = DiffusersBitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_use_double_quant=True,
+)
+
+transformer_4bit = FluxTransformer2DModel.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
     subfolder="transformer",
-    quantization_config=double_quant_config,
+    quantization_config=quant_config,
+    torch_dtype=torch.float16,
 )
 ```
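+These options compose. As a sketch, a single config combining NF4, nested quantization, and a `bfloat16` compute dtype (all standard `BitsAndBytesConfig` parameters):
+
+```py
+quant_config = DiffusersBitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_use_double_quant=True,
+    bnb_4bit_compute_dtype=torch.bfloat16,
+)
+```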
 ## Dequantizing `bitsandbytes` models
 
-Once quantized, you can dequantize the model to the original precision but this might result in a small quality loss of the model. Make sure you have enough GPU RAM to fit the dequantized model.
+Once quantized, you can dequantize a model to its original precision, but this might result in a small loss of quality. Make sure you have enough GPU RAM to fit the dequantized model.
 
 ```python
-from diffusers import BitsAndBytesConfig
+import torch
+from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
+from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
 
-double_quant_config = BitsAndBytesConfig(
+from diffusers import FluxTransformer2DModel
+from transformers import T5EncoderModel
+
+quant_config = TransformersBitsAndBytesConfig(
     load_in_4bit=True,
     bnb_4bit_use_double_quant=True,
 )
 
-double_quant_model = SD3Transformer2DModel.from_pretrained(
-    "stabilityai/stable-diffusion-3-medium-diffusers",
-    subfolder="transformer",
-    quantization_config=double_quant_config,
+text_encoder_2_4bit = T5EncoderModel.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
+    subfolder="text_encoder_2",
+    quantization_config=quant_config,
+    torch_dtype=torch.float16,
 )
-model.dequantize()
+
+quant_config = DiffusersBitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_use_double_quant=True,
+)
+
+transformer_4bit = FluxTransformer2DModel.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
+    subfolder="transformer",
+    quantization_config=quant_config,
+    torch_dtype=torch.float16,
+)
+
+text_encoder_2_4bit.dequantize()
+transformer_4bit.dequantize()
 ```
 
 ## Resources