Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-29 07:22:12 +03:00)
deepspeed refactor
@@ -866,13 +866,16 @@ def prepare_rotary_positional_embeddings(
     return freqs_cos, freqs_sin
 
 
-def get_optimizer(accelerator, args, params_to_optimize):
+def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
     # Use DeepSpeed optimizer
-    if (
-        accelerator.state.deepspeed_plugin is not None
-        and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
-    ):
-        return DummyOptim(params_to_optimize, lr=args.learning_rate)
+    if use_deepspeed:
+        return DummyOptim(
+            params_to_optimize,
+            lr=args.learning_rate,
+            betas=(args.adam_beta1, args.adam_beta2),
+            eps=args.adam_epsilon,
+            weight_decay=args.adam_weight_decay,
+        )
 
     # Optimizer creation
     supported_optimizers = ["adam", "adamw", "prodigy"]
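For context, a minimal self-contained sketch of the new calling convention. The AdamW branch stands in for the script's full "adam"/"adamw"/"prodigy" dispatch, and the Namespace fields are stand-ins for the script's CLI flags; with use_deepspeed=True, DummyOptim is only a placeholder that accelerator.prepare() later swaps for the optimizer defined in the DeepSpeed config.

from argparse import Namespace

import torch
from accelerate.utils import DummyOptim


def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
    if use_deepspeed:
        # Placeholder optimizer; the real one is built from the DeepSpeed
        # config when accelerator.prepare() runs.
        return DummyOptim(
            params_to_optimize,
            lr=args.learning_rate,
            betas=(args.adam_beta1, args.adam_beta2),
            eps=args.adam_epsilon,
            weight_decay=args.adam_weight_decay,
        )
    # Stand-in for the script's optimizer dispatch ("adam"/"adamw"/"prodigy").
    return torch.optim.AdamW(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        eps=args.adam_epsilon,
        weight_decay=args.adam_weight_decay,
    )


args = Namespace(learning_rate=1e-4, adam_beta1=0.9, adam_beta2=0.95,
                 adam_epsilon=1e-8, adam_weight_decay=1e-4)
optimizer = get_optimizer(args, torch.nn.Linear(8, 8).parameters(), use_deepspeed=False)

Passing a plain boolean instead of the Accelerator keeps get_optimizer free of accelerate state, which appears to be the point of the refactor; the caller decides where the optimizer comes from.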
@@ -1214,7 +1217,16 @@ def main(args):
     else:
         params_to_optimize = [transformer_parameters_with_lr]
 
-    optimizer = get_optimizer(accelerator, args, params_to_optimize)
+    use_deepspeed_optimizer = (
+        accelerator.state.deepspeed_plugin is not None
+        and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
+    )
+    use_deepspeed_scheduler = (
+        accelerator.state.deepspeed_plugin is not None
+        and "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
+    )
+
+    optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)
 
     # Dataset and DataLoader
     train_dataset = VideoDataset(
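The two flags read the DeepSpeed plugin configuration once, up front. A sketch of the detection in isolation (under a plain python launch there is no DeepSpeed plugin, so both flags come out False and the regular optimizer/scheduler paths are taken):

from accelerate import Accelerator

accelerator = Accelerator()
plugin = accelerator.state.deepspeed_plugin  # None unless launched with a DeepSpeed config

# DummyOptim is used when the DeepSpeed config defines an optimizer...
use_deepspeed_optimizer = plugin is not None and "optimizer" in plugin.deepspeed_config
# ...but DummyScheduler when the config does NOT define a scheduler;
# note the "not in", mirroring the checks in the diff above.
use_deepspeed_scheduler = plugin is not None and "scheduler" not in plugin.deepspeed_config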
@@ -1268,15 +1280,12 @@ def main(args):
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
 
-    if (
-        accelerator.state.deepspeed_plugin is not None
-        and "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
-    ):
+    if use_deepspeed_scheduler:
         lr_scheduler = DummyScheduler(
             name=args.lr_scheduler,
             optimizer=optimizer,
             total_num_steps=args.max_train_steps * accelerator.num_processes,
-            num_warmup_steps=args.lr_awrmup_steps * accelerator.num_processes,
+            num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
         )
     else:
         lr_scheduler = get_scheduler(
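Downstream, the corrected lr_warmup_steps keyword (the removed line carried the lr_awrmup_steps typo) feeds the placeholder scheduler. The diff view is truncated inside the get_scheduler(...) call, so the arguments in the else branch below are an assumption based on the usual diffusers.optimization.get_scheduler signature, not the script's actual code:

from argparse import Namespace

import torch
from accelerate import Accelerator
from accelerate.utils import DummyScheduler
from diffusers.optimization import get_scheduler

args = Namespace(lr_scheduler="cosine", lr_warmup_steps=100, max_train_steps=1000)
accelerator = Accelerator()
optimizer = torch.optim.AdamW(torch.nn.Linear(4, 4).parameters(), lr=1e-4)

use_deepspeed_scheduler = (
    accelerator.state.deepspeed_plugin is not None
    and "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
)

if use_deepspeed_scheduler:
    # Placeholder; accelerator.prepare() swaps in the scheduler built
    # from the DeepSpeed config.
    lr_scheduler = DummyScheduler(
        name=args.lr_scheduler,
        optimizer=optimizer,
        total_num_steps=args.max_train_steps * accelerator.num_processes,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
    )
else:
    # Assumed arguments: the diff view is cut off at this call.
    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
    )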