Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)

[training] fixes to the quantization training script and add AdEMAMix optimizer as an option (#9806)

* fixes
* more fixes.
@@ -349,7 +349,7 @@ def parse_args(input_args=None):
         "--optimizer",
         type=str,
         default="AdamW",
-        help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
+        choices=["AdamW", "Prodigy", "AdEMAMix"],
     )
 
     parser.add_argument(
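Note: with choices set, argparse itself rejects any value outside the list at parse time, which is what makes the manual "unsupported optimizer" fallback further down removable. A minimal, self-contained sketch of that behavior (a toy parser, not the script's):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--optimizer", type=str, default="AdamW", choices=["AdamW", "Prodigy", "AdEMAMix"]
    )

    print(parser.parse_args(["--optimizer", "Prodigy"]).optimizer)  # Prodigy
    # parser.parse_args(["--optimizer", "sgd"]) exits with:
    # error: argument --optimizer: invalid choice: 'sgd'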
@@ -357,6 +357,11 @@ def parse_args(input_args=None):
         action="store_true",
         help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
     )
+    parser.add_argument(
+        "--use_8bit_ademamix",
+        action="store_true",
+        help="Whether or not to use 8-bit AdEMAMix from bitsandbytes.",
+    )
 
     parser.add_argument(
         "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
@@ -820,19 +825,18 @@ def main(args):
     params_to_optimize = [transformer_parameters_with_lr]
 
     # Optimizer creation
-    if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
-        logger.warning(
-            f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
-            "Defaulting to adamW"
-        )
-        args.optimizer = "adamw"
-
     if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
         logger.warning(
             f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
             f"set to {args.optimizer.lower()}"
         )
 
+    if args.use_8bit_ademamix and not args.optimizer.lower() == "ademamix":
+        logger.warning(
+            f"use_8bit_ademamix is ignored when optimizer is not set to 'AdEMAMix'. Optimizer was "
+            f"set to {args.optimizer.lower()}"
+        )
+
     if args.optimizer.lower() == "adamw":
         if args.use_8bit_adam:
             try:
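With argparse now enforcing choices, the old "unsupported optimizer, defaulting to adamW" fallback is unreachable and is deleted; the surviving checks only warn when an 8-bit flag does not match the selected optimizer.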
@@ -853,6 +857,20 @@ def main(args):
             eps=args.adam_epsilon,
         )
 
+    elif args.optimizer.lower() == "ademamix":
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            raise ImportError(
+                "To use AdEMAMix (or its 8bit variant), please install the bitsandbytes library: `pip install -U bitsandbytes`."
+            )
+        if args.use_8bit_ademamix:
+            optimizer_class = bnb.optim.AdEMAMix8bit
+        else:
+            optimizer_class = bnb.optim.AdEMAMix
+
+        optimizer = optimizer_class(params_to_optimize)
+
     if args.optimizer.lower() == "prodigy":
         try:
             import prodigyopt
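For reference, a minimal sketch of what the new branch instantiates, assuming bitsandbytes >= 0.44.0 (the release that added AdEMAMix); the toy model and learning rate are placeholders, not values from the script:

    import torch
    import bitsandbytes as bnb

    model = torch.nn.Linear(64, 64)
    # Param groups mirror params_to_optimize in the script: each group carries
    # its own "lr", so no top-level lr kwarg is needed.
    params = [{"params": model.parameters(), "lr": 1e-4}]

    optimizer = bnb.optim.AdEMAMix(params)        # full-precision optimizer states
    # optimizer = bnb.optim.AdEMAMix8bit(params)  # 8-bit states to save memory
    # Note: bitsandbytes optimizers require a CUDA device to actually step.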
@@ -868,7 +886,6 @@ def main(args):
 
         optimizer = optimizer_class(
             params_to_optimize,
-            lr=args.learning_rate,
             betas=(args.adam_beta1, args.adam_beta2),
             beta3=args.prodigy_beta3,
             weight_decay=args.adam_weight_decay,
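The explicit lr kwarg is dropped here, most likely because each entry of params_to_optimize already carries its own "lr" (see transformer_parameters_with_lr above), and a per-group value overrides the top-level one in torch-style optimizers anyway.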
@@ -1020,12 +1037,12 @@ def main(args):
                 model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
                 model_input = model_input.to(dtype=weight_dtype)
 
-                vae_scale_factor = 2 ** (len(vae_config_block_out_channels))
+                vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)
 
                 latent_image_ids = FluxPipeline._prepare_latent_image_ids(
                     model_input.shape[0],
-                    model_input.shape[2],
-                    model_input.shape[3],
+                    model_input.shape[2] // 2,
+                    model_input.shape[3] // 2,
                     accelerator.device,
                     weight_dtype,
                 )
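The off-by-one matters: a VAE with N entries in block_out_channels downsamples only N - 1 times, once per transition between blocks. Assuming the Flux VAE config of [128, 256, 512, 512], a quick check of the arithmetic:

    # Hypothetical values mirroring the Flux VAE config.
    vae_config_block_out_channels = [128, 256, 512, 512]

    wrong = 2 ** len(vae_config_block_out_channels)        # 16, over-counts a stage
    right = 2 ** (len(vae_config_block_out_channels) - 1)  # 8, the real compression
    assert (wrong, right) == (16, 8)

    # A 1024px image gives 1024 // 8 = 128 latent pixels per side, and
    # _prepare_latent_image_ids now receives 128 // 2 = 64 to account for
    # Flux's 2x2 latent packing.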
@@ -1059,7 +1076,7 @@ def main(args):
                 )
 
                 # handle guidance
-                if transformer.config.guidance_embeds:
+                if unwrap_model(transformer).config.guidance_embeds:
                     guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
                     guidance = guidance.expand(model_input.shape[0])
                 else:
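unwrap_model strips the wrapper that accelerator.prepare may have placed around the transformer (e.g. DistributedDataParallel), which does not proxy attributes like .config; the unwrapped module exposes the real model config.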
@@ -1082,8 +1099,8 @@ def main(args):
                 )[0]
                 model_pred = FluxPipeline._unpack_latents(
                     model_pred,
-                    height=int(model_input.shape[2] * vae_scale_factor / 2),
-                    width=int(model_input.shape[3] * vae_scale_factor / 2),
+                    height=model_input.shape[2] * vae_scale_factor,
+                    width=model_input.shape[3] * vae_scale_factor,
                     vae_scale_factor=vae_scale_factor,
                 )
 
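Both formulas recover the same pixel dimensions; the rewrite just expresses them directly in terms of the corrected scale factor. With a hypothetical 1024x1024 training image:

    latent_h = 1024 // 8       # model_input.shape[2] == 128

    old_scale, new_scale = 16, 8
    assert int(latent_h * old_scale / 2) == latent_h * new_scale == 1024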