mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

deepspeed refactor

commit ec8d483e72
parent 0d95b0c5c0
Author: Aryan
Date:   2024-09-16 13:54:34 +02:00


@@ -866,13 +866,16 @@ def prepare_rotary_positional_embeddings(
     return freqs_cos, freqs_sin
 
 
-def get_optimizer(accelerator, args, params_to_optimize):
+def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
     # Use DeepSpeed optimizer
-    if (
-        accelerator.state.deepspeed_plugin is not None
-        and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
-    ):
-        return DummyOptim(params_to_optimize, lr=args.learning_rate)
+    if use_deepspeed:
+        return DummyOptim(
+            params_to_optimize,
+            lr=args.learning_rate,
+            betas=(args.adam_beta1, args.adam_beta2),
+            eps=args.adam_epsilon,
+            weight_decay=args.adam_weight_decay,
+        )
 
     # Optimizer creation
     supported_optimizers = ["adam", "adamw", "prodigy"]
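The refactor moves the DeepSpeed check out of the helper: callers now pass use_deepspeed, and the DummyOptim placeholder receives the full set of Adam hyperparameters instead of only the learning rate. Below is a minimal sketch of the new shape of get_optimizer; the non-DeepSpeed branch is simplified to plain AdamW, whereas the real script dispatches between "adam", "adamw", and "prodigy".

import torch
from accelerate.utils import DummyOptim


def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
    if use_deepspeed:
        # Placeholder object; when DeepSpeed's config declares an "optimizer"
        # section, the real optimizer is built from that config at prepare time.
        return DummyOptim(
            params_to_optimize,
            lr=args.learning_rate,
            betas=(args.adam_beta1, args.adam_beta2),
            eps=args.adam_epsilon,
            weight_decay=args.adam_weight_decay,
        )
    # Simplified fallback for illustration; the script supports more optimizers.
    return torch.optim.AdamW(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        eps=args.adam_epsilon,
        weight_decay=args.adam_weight_decay,
    )

In the script, params_to_optimize is a list of parameter-group dicts (for example [transformer_parameters_with_lr]), which both branches accept unchanged.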
@@ -1214,7 +1217,16 @@ def main(args):
     else:
         params_to_optimize = [transformer_parameters_with_lr]
 
-    optimizer = get_optimizer(accelerator, args, params_to_optimize)
+    use_deepspeed_optimizer = (
+        accelerator.state.deepspeed_plugin is not None
+        and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
+    )
+    use_deepspeed_scheduler = (
+        accelerator.state.deepspeed_plugin is not None
+        and "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
+    )
+
+    optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)
 
     # Dataset and DataLoader
     train_dataset = VideoDataset(
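Both flags are plain key lookups: accelerate exposes the parsed DeepSpeed config as a dict on accelerator.state.deepspeed_plugin.deepspeed_config. A small self-contained illustration with a made-up config follows; the section names are standard DeepSpeed keys, the values are invented, and the is-not-None guard from the real code is omitted.

# Hypothetical DeepSpeed config; only the presence of top-level keys matters here.
deepspeed_config = {
    "train_micro_batch_size_per_gpu": 1,
    "optimizer": {"type": "AdamW", "params": {"lr": "auto"}},
    # no "scheduler" section
}

use_deepspeed_optimizer = "optimizer" in deepspeed_config      # True
use_deepspeed_scheduler = "scheduler" not in deepspeed_config  # True
print(use_deepspeed_optimizer, use_deepspeed_scheduler)

With this config, get_optimizer returns the DummyOptim placeholder and the DummyScheduler branch later in main() is taken.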
@@ -1268,15 +1280,12 @@
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
 
-    if (
-        accelerator.state.deepspeed_plugin is not None
-        and "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
-    ):
+    if use_deepspeed_scheduler:
         lr_scheduler = DummyScheduler(
             name=args.lr_scheduler,
             optimizer=optimizer,
             total_num_steps=args.max_train_steps * accelerator.num_processes,
-            num_warmup_steps=args.lr_awrmup_steps * accelerator.num_processes,
+            num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
         )
     else:
         lr_scheduler = get_scheduler(
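With the flags hoisted, the scheduler branch reduces to a single name, and the commit also fixes the lr_awrmup_steps typo in the warmup argument. The sketch below shows how the selection reads after the change; the wrapper function is hypothetical, and the exact keyword set passed to get_scheduler is an assumption based on the usual diffusers training-script pattern, while the DummyScheduler call mirrors the diff.

from accelerate.utils import DummyScheduler
from diffusers.optimization import get_scheduler


def build_lr_scheduler(args, optimizer, accelerator, use_deepspeed_scheduler):
    # Hypothetical wrapper; in the script this branch lives inline in main().
    if use_deepspeed_scheduler:
        # Placeholder scheduler used on the DeepSpeed code path.
        return DummyScheduler(
            name=args.lr_scheduler,
            optimizer=optimizer,
            total_num_steps=args.max_train_steps * accelerator.num_processes,
            num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        )
    # Regular diffusers scheduler; extra kwargs (num_cycles, power, ...) may differ.
    return get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
    )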