From 768d0ea6fa6a305d12df1feda2afae3ec80aa449 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Thu, 14 Aug 2025 08:04:33 +0200
Subject: [PATCH] try to make dreambooth script work; accelerator backward not
 playing well

---
 .../train_dreambooth_lora_flux_advanced.py    | 30 +++++++++++++++++---
 1 file changed, 27 insertions(+), 3 deletions(-)

diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
index 9fea299421..4207efd214 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
@@ -67,6 +67,8 @@ from diffusers import (
     FlowMatchEulerDiscreteScheduler,
     FluxPipeline,
     FluxTransformer2DModel,
+    ParallelConfig,
+    enable_parallelism,
 )
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import (
@@ -805,6 +807,8 @@ def parse_args(input_args=None):
         ],
         help="The image interpolation method to use for resizing images.",
     )
+    parser.add_argument("--context_parallel_degree", type=int, default=1, help="The degree for context parallelism.")
+    parser.add_argument("--context_parallel_type", type=str, default="ulysses", help="The type of context parallelism to use. Choose between 'ulysses' and 'ring'.")
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -1347,15 +1351,27 @@ def main(args):
 
     logging_dir = Path(args.output_dir, args.logging_dir)
 
+    cp_degree = args.context_parallel_degree
     accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
-    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+    if cp_degree > 1:
+        kwargs = []
+    else:
+        kwargs = [DistributedDataParallelKwargs(find_unused_parameters=True)]
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
         log_with=args.report_to,
         project_config=accelerator_project_config,
-        kwargs_handlers=[kwargs],
+        kwargs_handlers=kwargs,
     )
+    if cp_degree > 1 and not torch.distributed.is_initialized():
+        if not torch.cuda.is_available():
+            raise ValueError("Context parallelism is only tested on CUDA devices.")
+        if os.environ.get("WORLD_SIZE", None) is None:
+            raise ValueError("Try launching the program with `torchrun --nproc_per_node <num_gpus>` instead of `accelerate launch`.")
+        torch.distributed.init_process_group("nccl")
+        rank = torch.distributed.get_rank()
+        torch.cuda.set_device(torch.device("cuda", rank % torch.cuda.device_count()))
 
     # Disable AMP for MPS.
     if torch.backends.mps.is_available():
@@ -1977,6 +1993,14 @@ def main(args):
         power=args.lr_power,
     )
 
+    # Enable context parallelism
+    if cp_degree > 1:
+        ring_degree = cp_degree if args.context_parallel_type == "ring" else None
+        ulysses_degree = cp_degree if args.context_parallel_type == "ulysses" else None
+        transformer.parallelize(config=ParallelConfig(ring_degree=ring_degree, ulysses_degree=ulysses_degree))
+        transformer.set_attention_backend("_native_cudnn")
+    parallel_context = enable_parallelism(transformer) if cp_degree > 1 else nullcontext()
+
     # Prepare everything with our `accelerator`.
     if not freeze_text_encoder:
         if args.enable_t5_ti:
@@ -2131,7 +2155,7 @@ def main(args):
                 logger.info(f"PIVOT TRANSFORMER {epoch}")
                 optimizer.param_groups[0]["lr"] = 0.0
 
-            with accelerator.accumulate(models_to_accumulate):
+            with accelerator.accumulate(models_to_accumulate), parallel_context:
                 prompts = batch["prompts"]
 
                 # encode batch prompts when custom prompts are provided for each image -
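
Example launch command. When --context_parallel_degree > 1, the script expects to
be started with torchrun rather than `accelerate launch`, since it initializes the
process group itself. A minimal sketch assuming two GPUs on one node; the model id,
prompt, and paths below are illustrative placeholders, not part of the patch:

  torchrun --nproc_per_node 2 \
    examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py \
    --pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
    --instance_data_dir="path/to/instance/images" \
    --instance_prompt="a photo of sks dog" \
    --output_dir="flux-lora-cp-test" \
    --mixed_precision="bf16" \
    --context_parallel_degree=2 \
    --context_parallel_type="ulysses"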