mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

add option for 8bit adam

Author: patil-suraj
Date:   2022-09-26 17:02:02 +02:00
Parent: 16ecc089e3
Commit: 87bc75231a
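
In effect, the commit selects the optimizer class from a new CLI flag and passes the same hyperparameters either way. Below is a minimal standalone sketch of that pattern, not code from the diff: the `Linear` model and the hyperparameter values are placeholders, and it assumes bitsandbytes may or may not be installed.

# Standalone sketch of the optimizer selection this commit introduces.
# The Linear model and hyperparameter values are placeholders, not taken from the script.
import torch

use_8bit_adam = True  # stands in for args.use_8bit_adam

if use_8bit_adam:
    try:
        import bitsandbytes as bnb  # optional dependency providing 8-bit optimizers

        optimizer_class = bnb.optim.AdamW8bit
    except ImportError:
        raise ImportError(
            "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
        )
else:
    optimizer_class = torch.optim.AdamW

model = torch.nn.Linear(4, 4)  # placeholder for the UNet being fine-tuned
optimizer = optimizer_class(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.999),
    weight_decay=1e-2,
)
print(type(optimizer).__name__)  # AdamW8bit when the flag is set and bitsandbytes is installed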

@@ -149,6 +149,9 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
     )
+    parser.add_argument(
+        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+    )
     parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
     parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
     parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
@@ -401,7 +404,19 @@ def main():
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )
 
-    optimizer = torch.optim.AdamW(
+    if args.use_8bit_adam:
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            raise ImportError(
+                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+            )
+
+        optimizer_class = bnb.optim.AdamW8bit
+    else:
+        optimizer_class = torch.optim.AdamW
+
+    optimizer = optimizer_class(
         unet.parameters(),  # only optimize unet
         lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),