[training] feat: enable quantization for hidream lora training. (#11494)
* feat: enable quantization for hidream lora training.
* better handle compute dtype.
* finalize.
* fix dtype.

Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
@@ -117,3 +117,30 @@ We provide several options for optimizing memory usage:
* `--use_8bit_adam`: When enabled, we will use the 8bit version of AdamW provided by the `bitsandbytes` library.
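As a rough sketch of what the flag above typically does (this helper is not part of the diff; `build_optimizer` and `lora_params` are hypothetical names, and the hyperparameters are illustrative), `--use_8bit_adam` swaps the optimizer class for the 8-bit AdamW implementation from `bitsandbytes`:

```python
import bitsandbytes as bnb
import torch


def build_optimizer(lora_params, use_8bit_adam: bool, lr: float = 1e-4):
    # Pick the 8-bit AdamW from bitsandbytes when requested, otherwise plain AdamW.
    optimizer_cls = bnb.optim.AdamW8bit if use_8bit_adam else torch.optim.AdamW
    return optimizer_cls(lora_params, lr=lr, betas=(0.9, 0.999), weight_decay=1e-4, eps=1e-8)
```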
Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/) of the `HiDreamImagePipeline` to learn more about the model.
## Using quantization
You can quantize the base model with [`bitsandbytes`](https://huggingface.co/docs/bitsandbytes/index) to reduce memory usage. To do so, pass the path to a JSON file via `--bnb_quantization_config_path`. The file should hold the keyword arguments used to initialize a `BitsAndBytesConfig`. Below is an example JSON file:
```json
{
  "load_in_4bit": true,
  "bnb_4bit_quant_type": "nf4"
}
```
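The training script reads this JSON and unpacks it into a `BitsAndBytesConfig`, aligning the 4-bit compute dtype with the mixed-precision setting (see the script diff further below). A minimal, self-contained sketch of that mapping, assuming a bf16 run; the file path and checkpoint id here are illustrative, not taken from the diff:

```python
import json

import torch
from diffusers import BitsAndBytesConfig, HiDreamImageTransformer2DModel

# Load the JSON file passed via --bnb_quantization_config_path.
with open("bnb_config.json", "r") as f:  # illustrative path
    config_kwargs = json.load(f)

# Mirror the script's behavior: for 4-bit loading, compute in the training dtype.
if config_kwargs.get("load_in_4bit"):
    config_kwargs["bnb_4bit_compute_dtype"] = torch.bfloat16

quantization_config = BitsAndBytesConfig(**config_kwargs)
transformer = HiDreamImageTransformer2DModel.from_pretrained(
    "HiDream-ai/HiDream-I1-Full",  # assumed checkpoint id; substitute your own
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
```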
Below are memory readings from a training run with and without NF4 quantization:
```
(with quantization)
Memory (before device placement): 9.085089683532715 GB.
Memory (after device placement): 34.59585428237915 GB.
Memory (after backward): 36.90267467498779 GB.

(without quantization)
Memory (before device placement): 0.0 GB.
Memory (after device placement): 57.6400408744812 GB.
Memory (after backward): 59.932212829589844 GB.
```
We see non-zero memory usage before device placement in the quantized case because, by default, `bitsandbytes`-quantized models are placed on the GPU while they are loaded.
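The exact logging helper behind the numbers above is not shown in this diff. If you want to collect comparable readings yourself, a minimal sketch based on PyTorch's allocator statistics (an assumption, not the script's own code):

```python
import torch


def log_cuda_memory(tag: str) -> None:
    # Allocated memory on the current CUDA device, reported in GiB.
    allocated_gb = torch.cuda.memory_allocated() / 1024**3
    print(f"Memory ({tag}): {allocated_gb} GB.")


# Example usage around model loading and a training step:
# log_cuda_memory("before device placement")
# model.to("cuda")
# log_cuda_memory("after device placement")
# loss.backward()
# log_cuda_memory("after backward")
```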
@@ -16,6 +16,7 @@
import argparse
import copy
import itertools
import json
import logging
import math
import os
@@ -27,14 +28,13 @@ from pathlib import Path
import numpy as np
import torch
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from huggingface_hub.utils import insecure_hashlib
from peft import LoraConfig, set_peft_model_state_dict
from peft import LoraConfig, prepare_model_for_kbit_training, set_peft_model_state_dict
from peft.utils import get_peft_model_state_dict
from PIL import Image
from PIL.ImageOps import exif_transpose
@@ -47,6 +47,7 @@ from transformers import AutoTokenizer, CLIPTokenizer, LlamaForCausalLM, Pretrai
import diffusers
from diffusers import (
    AutoencoderKL,
    BitsAndBytesConfig,
    FlowMatchEulerDiscreteScheduler,
    HiDreamImagePipeline,
    HiDreamImageTransformer2DModel,
@@ -282,6 +283,12 @@ def parse_args(input_args=None):
        default="meta-llama/Meta-Llama-3.1-8B-Instruct",
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--bnb_quantization_config_path",
        type=str,
        default=None,
        help="Quantization config in a JSON file that will be used to define the bitsandbytes quant config of the DiT.",
    )
    parser.add_argument(
        "--revision",
        type=str,
@@ -1056,6 +1063,14 @@ def main(args):
        args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_3"
    )

    # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
    # as these weights are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Load scheduler and models
    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="scheduler", revision=args.revision, shift=3.0
@@ -1064,20 +1079,31 @@ def main(args):
    text_encoder_one, text_encoder_two, text_encoder_three, text_encoder_four = load_text_encoders(
        text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
    )

    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="vae",
        revision=args.revision,
        variant=args.variant,
    )
    quantization_config = None
    if args.bnb_quantization_config_path is not None:
        with open(args.bnb_quantization_config_path, "r") as f:
            config_kwargs = json.load(f)
        if "load_in_4bit" in config_kwargs and config_kwargs["load_in_4bit"]:
            config_kwargs["bnb_4bit_compute_dtype"] = weight_dtype
        quantization_config = BitsAndBytesConfig(**config_kwargs)

    transformer = HiDreamImageTransformer2DModel.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="transformer",
        revision=args.revision,
        variant=args.variant,
        quantization_config=quantization_config,
        torch_dtype=weight_dtype,
        force_inference_output=True,
    )
    if args.bnb_quantization_config_path is not None:
        transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False)

    # We only train the additional adapter LoRA layers
    transformer.requires_grad_(False)
@@ -1087,14 +1113,6 @@ def main(args):
    text_encoder_three.requires_grad_(False)
    text_encoder_four.requires_grad_(False)

    # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
    # as these weights are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
        # due to pytorch#99272, MPS does not yet support bfloat16.
        raise ValueError(
@@ -1109,7 +1127,12 @@ def main(args):
    text_encoder_three.to(**to_kwargs)
    text_encoder_four.to(**to_kwargs)
    # we never offload the transformer to CPU, so we can just use the accelerator device
    transformer.to(accelerator.device, dtype=weight_dtype)
    transformer_to_kwargs = (
        {"device": accelerator.device}
        if args.bnb_quantization_config_path is not None
        else {"device": accelerator.device, "dtype": weight_dtype}
    )
    transformer.to(**transformer_to_kwargs)

    # Initialize a text encoding pipeline and keep it to CPU for now.
    text_encoding_pipeline = HiDreamImagePipeline.from_pretrained(
@@ -1695,10 +1718,11 @@ def main(args):
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        transformer = unwrap_model(transformer)
        if args.upcast_before_saving:
            transformer.to(torch.float32)
        else:
            transformer = transformer.to(weight_dtype)
        if args.bnb_quantization_config_path is None:
            if args.upcast_before_saving:
                transformer.to(torch.float32)
            else:
                transformer = transformer.to(weight_dtype)
        transformer_lora_layers = get_peft_model_state_dict(transformer)

        HiDreamImagePipeline.save_lora_weights(
@@ -179,7 +179,7 @@ class BitsAndBytesConfig(QuantizationConfigMixin):
    This is a wrapper class about all possible attributes and features that you can play with a model that has been
    loaded using `bitsandbytes`.

    This replaces `load_in_8bit` or `load_in_4bit`therefore both options are mutually exclusive.
    This replaces `load_in_8bit` or `load_in_4bit` therefore both options are mutually exclusive.

    Currently only supports `LLM.int8()`, `FP4`, and `NF4` quantization. If more methods are added to `bitsandbytes`,
    then more arguments will be added to this class.
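As a small illustration of the mutual exclusivity mentioned in the docstring, a sketch of two alternative configurations using the diffusers `BitsAndBytesConfig`; the NF4 settings match the README example above, while the compute dtype mirrors the training script's choice and is otherwise an assumption:

```python
import torch
from diffusers import BitsAndBytesConfig

# 4-bit NF4 config, matching the JSON example in the README hunk above.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# 8-bit (LLM.int8()) config. `load_in_8bit` and `load_in_4bit` are mutually
# exclusive, so a given config enables at most one of them.
int8_config = BitsAndBytesConfig(load_in_8bit=True)
```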