mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Import DummyOptim and DummyScheduler only when required

author Aryan
date 2024-09-17 03:50:57 +02:00
parent 6d704ce770
commit f07755fd04


@@ -26,7 +26,7 @@ import torch
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
-from accelerate.utils import DistributedDataParallelKwargs, DummyOptim, DummyScheduler, ProjectConfiguration, set_seed
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
 from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
 from torch.utils.data import DataLoader, Dataset
@@ -869,6 +869,8 @@ def prepare_rotary_positional_embeddings(
 def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
     # Use DeepSpeed optimizer
     if use_deepspeed:
+        from accelerate.utils import DummyOptim
+
         return DummyOptim(
             params_to_optimize,
             lr=args.learning_rate,
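
For readers unfamiliar with the pattern, here is a self-contained sketch of what get_optimizer looks like after this change: the DeepSpeed-only helper is imported inside the guarded branch, so it is only loaded when DeepSpeed is actually in use. The torch.optim.AdamW fallback and the exact keyword set are assumptions for illustration; the real script supports several optimizer choices and more flags.

    import torch

    def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
        if use_deepspeed:
            # Deferred import: only executed on the DeepSpeed path.
            from accelerate.utils import DummyOptim

            # DummyOptim is a placeholder; the real optimizer is built
            # by DeepSpeed from its config when accelerate prepares the
            # training objects.
            return DummyOptim(
                params_to_optimize,
                lr=args.learning_rate,
                betas=(args.adam_beta1, args.adam_beta2),
                eps=args.adam_epsilon,
                weight_decay=args.adam_weight_decay,
            )

        # Assumed fallback for illustration; the actual script selects
        # among several optimizer implementations here.
        return torch.optim.AdamW(
            params_to_optimize,
            lr=args.learning_rate,
            betas=(args.adam_beta1, args.adam_beta2),
            eps=args.adam_epsilon,
            weight_decay=args.adam_weight_decay,
        )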
@@ -1281,6 +1283,8 @@ def main(args):
         overrode_max_train_steps = True
 
     if use_deepspeed_scheduler:
+        from accelerate.utils import DummyScheduler
+
         lr_scheduler = DummyScheduler(
             name=args.lr_scheduler,
             optimizer=optimizer,
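
The scheduler change follows the same shape. A sketch, with the same caveats: the diff truncates the call after optimizer=, so the remaining keyword arguments are assumptions; the helper name get_lr_scheduler and the num_update_steps parameter are hypothetical (in the script this logic sits inline in main()), and diffusers' get_scheduler is assumed as the non-DeepSpeed fallback.

    def get_lr_scheduler(args, optimizer, num_update_steps, use_deepspeed_scheduler: bool = False):
        if use_deepspeed_scheduler:
            # Deferred import, mirroring get_optimizer above.
            from accelerate.utils import DummyScheduler

            # DummyScheduler stands in for the scheduler declared in the
            # DeepSpeed config; accelerate resolves it during prepare().
            return DummyScheduler(
                name=args.lr_scheduler,
                optimizer=optimizer,
                total_num_steps=num_update_steps,       # assumed value
                num_warmup_steps=args.lr_warmup_steps,  # assumed value
            )

        # Standard (non-DeepSpeed) path, assumed for illustration.
        from diffusers.optimization import get_scheduler

        return get_scheduler(
            args.lr_scheduler,
            optimizer=optimizer,
            num_warmup_steps=args.lr_warmup_steps,
            num_training_steps=num_update_steps,
        )

Either way, the net effect of the commit is the same: DummyOptim and DummyScheduler are no longer imported at module load time, only inside the branches that need them.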