From e140c0562ed87ec37273796d3e44d2f3aae4d9c2 Mon Sep 17 00:00:00 2001
From: jiaqiw09 <60021713+jiaqiw09@users.noreply.github.com>
Date: Fri, 27 Oct 2023 12:19:14 -0500
Subject: [PATCH] fix 'find_unused_parameters' error when running on multiple
 GPUs (#5355)

* fix 'find_unused_parameters' error when running on multiple GPUs or NPUs

* fix code check by importing modules in alphabetical order

---------

Co-authored-by: jiaqiw
Co-authored-by: Dhruv Nair
---
 examples/dreambooth/train_dreambooth_lora_sdxl.py       | 5 +++--
 examples/text_to_image/train_text_to_image_lora_sdxl.py | 5 +++--
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index d7df6d4ef5..b729f7e189 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -31,7 +31,7 @@ import torch.utils.checkpoint
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
-from accelerate.utils import ProjectConfiguration, set_seed
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
 from PIL import Image
@@ -579,12 +579,13 @@ def main(args):
     logging_dir = Path(args.output_dir, args.logging_dir)
 
     accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
-
+    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
         log_with=args.report_to,
         project_config=accelerator_project_config,
+        kwargs_handlers=[kwargs],
     )
 
     if args.report_to == "wandb":
diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py
index 35de6eedca..74fc01aee3 100644
--- a/examples/text_to_image/train_text_to_image_lora_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py
@@ -33,7 +33,7 @@ import torch.utils.checkpoint
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
-from accelerate.utils import ProjectConfiguration, set_seed
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from datasets import load_dataset
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
@@ -491,12 +491,13 @@ def main(args):
     logging_dir = Path(args.output_dir, args.logging_dir)
 
     accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
-
+    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
         log_with=args.report_to,
         project_config=accelerator_project_config,
+        kwargs_handlers=[kwargs],
    )
 
     if args.report_to == "wandb":
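
For context, the change both files make reduces to the minimal sketch below. LoRA training freezes most of the model, so some parameters can receive no gradient in a step, and torch's DistributedDataParallel then fails during gradient reduction unless find_unused_parameters=True is set; Accelerate forwards that flag through a kwargs handler. The literal argument values in the sketch (gradient_accumulation_steps=1, "fp16", the toy Linear model) are illustrative assumptions, not values taken from the patched scripts, which read them from CLI arguments.

# Minimal sketch of the pattern this patch introduces in both training scripts.
import torch
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs

ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(
    gradient_accumulation_steps=1,   # placeholder; the scripts use args.gradient_accumulation_steps
    mixed_precision="fp16",          # placeholder; the scripts use args.mixed_precision
    kwargs_handlers=[ddp_kwargs],    # forwarded to DistributedDataParallel when prepare() wraps the model
)

model = torch.nn.Linear(8, 8)        # stand-in for the UNet / text encoders
model = accelerator.prepare(model)   # under a multi-GPU launch, wraps the model in DDP with the flag set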