From f07755fd04450b2fdffac56e20b49f51d852c8f8 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Tue, 17 Sep 2024 03:50:57 +0200
Subject: [PATCH] import Dummy optim and scheduler only when required

---
 examples/cogvideo/train_cogvideox_lora.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py
index a61531a4cb..ff9183d780 100644
--- a/examples/cogvideo/train_cogvideox_lora.py
+++ b/examples/cogvideo/train_cogvideox_lora.py
@@ -26,7 +26,7 @@ import torch
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
-from accelerate.utils import DistributedDataParallelKwargs, DummyOptim, DummyScheduler, ProjectConfiguration, set_seed
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
 from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
 from torch.utils.data import DataLoader, Dataset
@@ -869,6 +869,8 @@ def prepare_rotary_positional_embeddings(
 def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
     # Use DeepSpeed optimzer
     if use_deepspeed:
+        from accelerate.utils import DummyOptim
+
         return DummyOptim(
             params_to_optimize,
             lr=args.learning_rate,
@@ -1281,6 +1283,8 @@ def main(args):
         overrode_max_train_steps = True
 
     if use_deepspeed_scheduler:
+        from accelerate.utils import DummyScheduler
+
         lr_scheduler = DummyScheduler(
             name=args.lr_scheduler,
             optimizer=optimizer,
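
For reference, a minimal sketch of the deferred-import pattern the patch applies: `DummyOptim` is imported only on the DeepSpeed code path, so the module-level imports no longer pull it in. The `args.learning_rate` field matches the script; the plain AdamW fallback below is illustrative and stands in for the script's full optimizer selection.

    # Sketch of the lazy-import pattern from the patch (not the full training script).
    import torch


    def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
        if use_deepspeed:
            # Imported here, inside the branch, so the import is only needed
            # when the DeepSpeed path is actually taken.
            from accelerate.utils import DummyOptim

            return DummyOptim(params_to_optimize, lr=args.learning_rate)

        # Illustrative fallback; the real script chooses among several optimizers.
        return torch.optim.AdamW(params_to_optimize, lr=args.learning_rate)

The scheduler side of the patch mirrors this: `DummyScheduler` is imported only inside the `use_deepspeed_scheduler` branch in `main()`.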