mirror of https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
apply suggestions from review; prodigy optimizer
YiYi Xu <yixu310@gmail.com>
@@ -360,6 +360,13 @@ def get_args():
     parser.add_argument(
         "--adam_beta2", type=float, default=0.95, help="The beta2 parameter for the Adam and Prodigy optimizers."
     )
+    parser.add_argument(
+        "--prodigy_beta3",
+        type=float,
+        default=None,
+        help="Coefficients for computing the Prodigy optimizer's stepsize using running averages. If set to None, uses the value of square root of beta2.",
+    )
+    parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
     parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
     parser.add_argument(
         "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
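As a worked illustration of the `--prodigy_beta3` help text above (not part of the diff): when the flag is left at its default of None, Prodigy falls back to the square root of beta2, so with the default `--adam_beta2` of 0.95 the effective beta3 is roughly 0.9747.

# Illustrative sketch only, not part of the training script: how the beta3
# default described in the --prodigy_beta3 help text resolves.
import math

adam_beta2 = 0.95      # default of --adam_beta2
prodigy_beta3 = None   # default of --prodigy_beta3

effective_beta3 = prodigy_beta3 if prodigy_beta3 is not None else math.sqrt(adam_beta2)
print(f"effective beta3: {effective_beta3:.4f}")  # ~0.9747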
@@ -371,6 +378,15 @@ def get_args():
         help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
     )
     parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+    parser.add_argument(
+        "--prodigy_use_bias_correction", type=bool, default=True, help="Turn on Adam's bias correction."
+    )
+    parser.add_argument(
+        "--prodigy_safeguard_warmup",
+        type=bool,
+        default=True,
+        help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage.",
+    )
 
     # Other information
     parser.add_argument("--tracker_name", type=str, default=None, help="Project tracker name")
@@ -851,9 +867,9 @@ def get_optimizer(args, params_to_optimize):
         )
         args.optimizer = "adamw"
 
-    if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
+    if args.use_8bit_adam and args.optimizer.lower() not in ["adam", "adamw"]:
         logger.warning(
-            f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
+            f"use_8bit_adam is ignored when optimizer is not set to 'Adam' or 'AdamW'. Optimizer was "
             f"set to {args.optimizer.lower()}"
         )
 
@@ -883,6 +899,38 @@ def get_optimizer(args, params_to_optimize):
             eps=args.adam_epsilon,
             weight_decay=args.adam_weight_decay,
         )
+    elif args.optimizer.lower() == "prodigy":
+        try:
+            import prodigyopt
+        except ImportError:
+            raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
+
+        optimizer_class = prodigyopt.Prodigy
+
+        if args.learning_rate <= 0.1:
+            logger.warning(
+                "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
+            )
+        if args.train_text_encoder and args.text_encoder_lr:
+            logger.warning(
+                f"Learning rates were provided both for the transformer and the text encoder - e.g. text_encoder_lr:"
+                f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
+                f"When using prodigy only learning_rate is used as the initial learning rate."
+            )
+            # Changes the learning rate of text_encoder_parameters to be --learning_rate
+            params_to_optimize[1]["lr"] = args.learning_rate
+
+        optimizer = optimizer_class(
+            params_to_optimize,
+            lr=args.learning_rate,
+            betas=(args.adam_beta1, args.adam_beta2),
+            beta3=args.prodigy_beta3,
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+            decouple=args.prodigy_decouple,
+            use_bias_correction=args.prodigy_use_bias_correction,
+            safeguard_warmup=args.prodigy_safeguard_warmup,
+        )
 
     return optimizer
 
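For context on the construction above, a minimal standalone sketch of `prodigyopt.Prodigy` follows. The model and training loop are placeholders, and only keyword arguments that already appear in the diff are passed. Prodigy estimates its own step size, which is why the script warns when `--learning_rate` is set well below 1.0.

# Minimal sketch (assumptions: toy model and data; only kwargs shown in the
# diff above are used). Prodigy adapts its own step size, so lr stays at 1.0.
import torch
import prodigyopt

model = torch.nn.Linear(16, 1)
optimizer = prodigyopt.Prodigy(
    model.parameters(),
    lr=1.0,                      # recommended starting point for Prodigy
    betas=(0.9, 0.95),
    beta3=None,                  # falls back to sqrt(beta2)
    weight_decay=1e-4,
    eps=1e-8,
    decouple=True,               # AdamW-style decoupled weight decay
    use_bias_correction=True,
    safeguard_warmup=True,
)

for _ in range(10):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 16)).pow(2).mean()
    loss.backward()
    optimizer.step()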
@@ -958,8 +1006,15 @@ def main(args):
         args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
     )
 
+    # CogVideoX-2b weights are stored in float16
+    # CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16
+    load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16
     transformer = CogVideoXTransformer3DModel.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
+        args.pretrained_model_name_or_path,
+        subfolder="transformer",
+        torch_dtype=load_dtype,
+        revision=args.revision,
+        variant=args.variant,
     )
 
     vae = AutoencoderKLCogVideoX.from_pretrained(
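A small sketch of the dtype heuristic introduced above: any checkpoint id containing "5b" is loaded in bfloat16, everything else in float16, matching the storage dtypes noted in the comments. The model ids below are only illustrative inputs, not values used by the training script.

# Sketch of the load-dtype heuristic from the diff; the model ids below are
# illustrative examples, not exercised by the training script itself.
import torch

def infer_load_dtype(pretrained_model_name_or_path: str) -> torch.dtype:
    return torch.bfloat16 if "5b" in pretrained_model_name_or_path.lower() else torch.float16

assert infer_load_dtype("THUDM/CogVideoX-5b") is torch.bfloat16
assert infer_load_dtype("THUDM/CogVideoX-2b") is torch.float16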
@@ -981,10 +1036,23 @@ def main(args):
     # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
     # as these weights are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32
-    if accelerator.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif accelerator.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
+    if accelerator.state.deepspeed_plugin:
+        # DeepSpeed is handling precision, use what's in the DeepSpeed config
+        if (
+            "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config
+            and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"]
+        ):
+            weight_dtype = torch.float16
+        if (
+            "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
+            and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
+        ):
+            weight_dtype = torch.bfloat16
+    else:
+        if accelerator.mixed_precision == "fp16":
+            weight_dtype = torch.float16
+        elif accelerator.mixed_precision == "bf16":
+            weight_dtype = torch.bfloat16
 
     if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
         # due to pytorch#99272, MPS does not yet support bfloat16.
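The precedence introduced above (a DeepSpeed config, when present, wins over `accelerator.mixed_precision`) can be read more easily as a small helper. This is a sketch with a hypothetical function name, not code from the script.

# Hypothetical helper summarizing the dtype resolution above: the DeepSpeed
# config, when present, takes precedence over accelerator.mixed_precision.
import torch

def resolve_weight_dtype(accelerator) -> torch.dtype:
    ds_plugin = accelerator.state.deepspeed_plugin
    if ds_plugin:
        config = ds_plugin.deepspeed_config
        if config.get("fp16", {}).get("enabled", False):
            return torch.float16
        if config.get("bf16", {}).get("enabled", False):
            return torch.bfloat16
        return torch.float32
    if accelerator.mixed_precision == "fp16":
        return torch.float16
    if accelerator.mixed_precision == "bf16":
        return torch.bfloat16
    return torch.float32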
@@ -1279,6 +1347,9 @@ def main(args):
     )
     vae_scale_factor_spatial = 2 ** (len(vae.config.block_out_channels) - 1)
 
+    # For DeepSpeed training
+    model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config
+
     for epoch in range(first_epoch, args.num_train_epochs):
         transformer.train()
         if args.train_text_encoder:
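For context on the `model_config` line above: DeepSpeed (like DDP) wraps the model, so attributes such as the config live on `.module`. The one-liner below restates that unwrap pattern as an illustrative helper, not code from the script.

# Illustrative helper (not in the script): read the config through the
# wrapper that DeepSpeed/DDP add around the model, if any.
def get_model_config(model):
    return model.module.config if hasattr(model, "module") else model.config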
@@ -1322,11 +1393,11 @@ def main(args):
                         width=args.width,
                         num_frames=num_frames,
                         vae_scale_factor_spatial=vae_scale_factor_spatial,
-                        patch_size=transformer.config.patch_size,
-                        attention_head_dim=transformer.config.attention_head_dim,
+                        patch_size=model_config.patch_size,
+                        attention_head_dim=model_config.attention_head_dim,
                         device=accelerator.device,
                     )
-                    if transformer.config.use_rotary_positional_embeddings
+                    if model_config.use_rotary_positional_embeddings
                     else None
                 )
 