mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
fix 'find_unused_parameters' error reported when running on multiple GPUs (#5355)
* fix 'find_unused_parameters' error reported when running on multiple GPUs or NPUs
* fix the code-style check by keeping module imports in alphabetical order

Co-authored-by: jiaqiw <wangjiaqi50@huawei.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
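The fix forwards find_unused_parameters=True to the underlying torch.nn.parallel.DistributedDataParallel wrapper through Accelerate's kwargs-handler mechanism. A minimal sketch of the pattern the diff applies, assuming the usual Accelerate launch flow; the placeholder values stand in for the scripts' parsed CLI arguments:

from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs

# find_unused_parameters=True lets DDP tolerate parameters that receive no
# gradient in a given step (e.g. frozen or conditionally executed submodules);
# without it, multi-GPU/NPU runs abort with a RuntimeError during backward.
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)

accelerator = Accelerator(
    gradient_accumulation_steps=1,  # placeholder for args.gradient_accumulation_steps
    mixed_precision="fp16",         # placeholder for args.mixed_precision
    kwargs_handlers=[ddp_kwargs],
)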
@@ -31,7 +31,7 @@ import torch.utils.checkpoint
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
-from accelerate.utils import ProjectConfiguration, set_seed
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
 from PIL import Image
@@ -579,12 +579,13 @@ def main(args):
     logging_dir = Path(args.output_dir, args.logging_dir)

     accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
-
+    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
         log_with=args.report_to,
         project_config=accelerator_project_config,
+        kwargs_handlers=[kwargs],
     )

     if args.report_to == "wandb":
@@ -33,7 +33,7 @@ import torch.utils.checkpoint
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
-from accelerate.utils import ProjectConfiguration, set_seed
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from datasets import load_dataset
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
@@ -491,12 +491,13 @@ def main(args):
     logging_dir = Path(args.output_dir, args.logging_dir)

     accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
-
+    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
         log_with=args.report_to,
         project_config=accelerator_project_config,
+        kwargs_handlers=[kwargs],
     )

     if args.report_to == "wandb":
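For reference, the same flag in raw PyTorch reproduces the failure mode the commit works around. A hypothetical minimal repro, assuming a launch via torchrun --nproc_per_node=N; the TwoHeads module is made up for illustration:

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

class TwoHeads(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.used = torch.nn.Linear(8, 8)
        self.unused = torch.nn.Linear(8, 8)  # never touched in forward

    def forward(self, x):
        return self.used(x)

dist.init_process_group(backend="nccl")  # torchrun sets RANK/WORLD_SIZE/LOCAL_RANK
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# Without find_unused_parameters=True, the backward pass fails with a
# RuntimeError pointing at the parameters of `unused`, since DDP expects
# every registered parameter to produce a gradient each iteration.
model = DDP(TwoHeads().cuda(), device_ids=[local_rank], find_unused_parameters=True)
model(torch.randn(4, 8, device="cuda")).sum().backward()
dist.destroy_process_group()

The flag is not free: DDP walks the autograd graph each iteration to find unused parameters, so it should only be enabled when the wrapped model actually has them, as the models trained by these scripts apparently do.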