Mirror of https://github.com/huggingface/diffusers.git

[training] fixes to the quantization training script and add AdEMAMix optimizer as an option (#9806)

* fixes

* more fixes.
Author: Sayak Paul
Date: 2024-10-31 15:46:00 +05:30
Committed by: GitHub
Parent: c1d4a0dded
Commit: 09b8aebd67


@@ -349,7 +349,7 @@ def parse_args(input_args=None):
         "--optimizer",
         type=str,
         default="AdamW",
-        help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
+        choices=["AdamW", "Prodigy", "AdEMAMix"],
     )
     parser.add_argument(
@@ -357,6 +357,11 @@ def parse_args(input_args=None):
         action="store_true",
         help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
     )
+    parser.add_argument(
+        "--use_8bit_ademamix",
+        action="store_true",
+        help="Whether or not to use 8-bit AdEMAMix from bitsandbytes.",
+    )
     parser.add_argument(
         "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
@@ -820,19 +825,18 @@ def main(args):
     params_to_optimize = [transformer_parameters_with_lr]
     # Optimizer creation
-    if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
-        logger.warning(
-            f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
-            "Defaulting to adamW"
-        )
-        args.optimizer = "adamw"
     if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
         logger.warning(
             f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
             f"set to {args.optimizer.lower()}"
         )
+    if args.use_8bit_ademamix and not args.optimizer.lower() == "ademamix":
+        logger.warning(
+            f"use_8bit_ademamix is ignored when optimizer is not set to 'AdEMAMix'. Optimizer was "
+            f"set to {args.optimizer.lower()}"
+        )
     if args.optimizer.lower() == "adamw":
         if args.use_8bit_adam:
             try:
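
The two warnings in this hunk only fire when an 8-bit flag is paired with the wrong optimizer; the flag is then ignored rather than raising. A hedged sketch of the same check factored into a helper (names are illustrative, not from the script):

def warn_on_ignored_8bit_flags(args, logger):
    # Each 8-bit switch is only honored by its matching optimizer family.
    opt = args.optimizer.lower()
    if args.use_8bit_adam and opt != "adamw":
        logger.warning(f"use_8bit_adam is ignored because optimizer is set to {opt!r}, not 'AdamW'.")
    if args.use_8bit_ademamix and opt != "ademamix":
        logger.warning(f"use_8bit_ademamix is ignored because optimizer is set to {opt!r}, not 'AdEMAMix'.")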
@@ -853,6 +857,20 @@ def main(args):
             eps=args.adam_epsilon,
         )
+    elif args.optimizer.lower() == "ademamix":
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            raise ImportError(
+                "To use AdEMAMix (or its 8bit variant), please install the bitsandbytes library: `pip install -U bitsandbytes`."
+            )
+        if args.use_8bit_ademamix:
+            optimizer_class = bnb.optim.AdEMAMix8bit
+        else:
+            optimizer_class = bnb.optim.AdEMAMix
+        optimizer = optimizer_class(params_to_optimize)
     if args.optimizer.lower() == "prodigy":
         try:
             import prodigyopt
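
The new branch mirrors the existing 8-bit Adam path: bitsandbytes ships both a full-precision AdEMAMix and an 8-bit variant, and the hunk instantiates whichever one the flag selects with its default hyperparameters. A standalone sketch of that selection, assuming a CUDA-enabled bitsandbytes recent enough to include these classes (they were added around v0.44) and a toy parameter:

import torch
import bitsandbytes as bnb  # bitsandbytes optimizers run fused CUDA kernels

# Toy parameter on GPU, standing in for the LoRA parameters of the transformer.
param = torch.nn.Parameter(torch.randn(4096, 4, device="cuda"))

use_8bit = False  # stand-in for args.use_8bit_ademamix
optimizer_class = bnb.optim.AdEMAMix8bit if use_8bit else bnb.optim.AdEMAMix
optimizer = optimizer_class([param])  # defaults only, as in the script

loss = (param ** 2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()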
@@ -868,7 +886,6 @@ def main(args):
         optimizer = optimizer_class(
             params_to_optimize,
-            lr=args.learning_rate,
             betas=(args.adam_beta1, args.adam_beta2),
             beta3=args.prodigy_beta3,
             weight_decay=args.adam_weight_decay,
@@ -1020,12 +1037,12 @@ def main(args):
                 model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
                 model_input = model_input.to(dtype=weight_dtype)
-                vae_scale_factor = 2 ** (len(vae_config_block_out_channels))
+                vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)
                 latent_image_ids = FluxPipeline._prepare_latent_image_ids(
                     model_input.shape[0],
-                    model_input.shape[2],
-                    model_input.shape[3],
+                    model_input.shape[2] // 2,
+                    model_input.shape[3] // 2,
                     accelerator.device,
                     weight_dtype,
                 )
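
Both changes in this hunk are about getting the latent geometry right: the VAE downsamples by a factor of 2 per block transition, i.e. 2 ** (len(block_out_channels) - 1), and the Flux transformer packs latents into 2x2 patches, so the rotary position ids must be built for half the latent height and width. A small worked example, assuming the standard four-entry Flux VAE config and a 1024x1024 training image:

block_out_channels = [128, 256, 512, 512]  # standard Flux VAE config (assumed here)
pixel_h = pixel_w = 1024

vae_scale_factor = 2 ** (len(block_out_channels) - 1)  # 8, where the old formula gave 16
latent_h, latent_w = pixel_h // vae_scale_factor, pixel_w // vae_scale_factor  # 128 x 128

# The transformer consumes 2x2-packed latents, so ids are prepared at half resolution:
ids_h, ids_w = latent_h // 2, latent_w // 2  # 64 x 64 -> 4096 image tokens

print(vae_scale_factor, (latent_h, latent_w), (ids_h, ids_w))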
@@ -1059,7 +1076,7 @@ def main(args):
                 )
                 # handle guidance
-                if transformer.config.guidance_embeds:
+                if unwrap_model(transformer).config.guidance_embeds:
                     guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
                     guidance = guidance.expand(model_input.shape[0])
                 else:
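
After accelerator.prepare() the transformer may be wrapped (e.g. in DistributedDataParallel, and possibly torch.compile's wrapper), and attributes like .config live on the underlying module, so the guidance check has to unwrap first. A sketch of the unwrap pattern these training scripts typically use (names assumed here, not copied from this file):

def unwrap_model(accelerator, model):
    # Peel off the accelerate wrapper (e.g. DDP)...
    model = accelerator.unwrap_model(model)
    # ...and, if the model was torch.compile'd, the compile wrapper as well.
    model = model._orig_mod if hasattr(model, "_orig_mod") else model
    return model

# Usage inside the training loop (guidance-distilled Flux checkpoints set guidance_embeds=True):
# if unwrap_model(accelerator, transformer).config.guidance_embeds:
#     ...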
@@ -1082,8 +1099,8 @@ def main(args):
                 )[0]
                 model_pred = FluxPipeline._unpack_latents(
                     model_pred,
-                    height=int(model_input.shape[2] * vae_scale_factor / 2),
-                    width=int(model_input.shape[3] * vae_scale_factor / 2),
+                    height=model_input.shape[2] * vae_scale_factor,
+                    width=model_input.shape[3] * vae_scale_factor,
                     vae_scale_factor=vae_scale_factor,
                 )
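
This change is the counterpart of the vae_scale_factor fix above: the old code inflated the scale factor to 16 and then divided by 2 here, while the new code uses the true factor of 8 directly, so the unpacked size is simply latent size times the VAE downsampling factor. A quick check of the arithmetic for a 1024px image (standard four-block Flux VAE assumed):

latent_h = 128                  # 1024 // 8
old = int(latent_h * 16 / 2)    # 1024 with the old, inflated scale factor
new = latent_h * 8              # 1024 with the corrected scale factor
assert old == new == 1024       # numerically the same, but the new form states the intent directly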