1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

try to make dreambooth script work; accelerator backward not playing well

This commit is contained in:
Aryan
2025-08-14 08:04:33 +02:00
parent cca53814a3
commit 768d0ea6fa

View File

@@ -67,6 +67,8 @@ from diffusers import (
FlowMatchEulerDiscreteScheduler,
FluxPipeline,
FluxTransformer2DModel,
ParallelConfig,
enable_parallelism,
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
@@ -805,6 +807,8 @@ def parse_args(input_args=None):
],
help="The image interpolation method to use for resizing images.",
)
parser.add_argument("--context_parallel_degree", type=int, default=1, help="The degree for context parallelism.")
parser.add_argument("--context_parallel_type", type=str, default="ulysses", help="The type of context parallelism to use. Choose between 'ulysses' and 'ring'.")
if input_args is not None:
args = parser.parse_args(input_args)
@@ -1347,15 +1351,28 @@ def main(args):
logging_dir = Path(args.output_dir, args.logging_dir)
cp_degree = args.context_parallel_degree
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
if cp_degree > 1:
kwargs = []
else:
kwargs = [DistributedDataParallelKwargs(find_unused_parameters=True)]
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_config=accelerator_project_config,
kwargs_handlers=[kwargs],
kwargs_handlers=kwargs,
)
if cp_degree > 1 and not torch.distributed.is_initialized():
if not torch.cuda.is_available():
raise ValueError("Context parallelism is only tested on CUDA devices.")
if os.environ.get("WORLD_SIZE", None) is None:
raise ValueError("Try launching the program with `torchrun --nproc_per_node <NUM_GPUS>` instead of `accelerate launch <NUM_GPUS>`.")
torch.distributed.init_process_group("nccl")
rank = torch.distributed.get_rank()
rank = accelerator.process_index
torch.cuda.set_device(torch.device("cuda", rank % torch.cuda.device_count()))
# Disable AMP for MPS.
if torch.backends.mps.is_available():
@@ -1977,6 +1994,14 @@ def main(args):
power=args.lr_power,
)
# Enable context parallelism
if cp_degree > 1:
ring_degree = cp_degree if args.context_parallel_type == "ring" else None
ulysses_degree = cp_degree if args.context_parallel_type == "ulysses" else None
transformer.parallelize(config=ParallelConfig(ring_degree=ring_degree, ulysses_degree=ulysses_degree))
transformer.set_attention_backend("_native_cudnn")
parallel_context = enable_parallelism(transformer) if cp_degree > 1 else nullcontext()
# Prepare everything with our `accelerator`.
if not freeze_text_encoder:
if args.enable_t5_ti:
@@ -2131,7 +2156,7 @@ def main(args):
logger.info(f"PIVOT TRANSFORMER {epoch}")
optimizer.param_groups[0]["lr"] = 0.0
with accelerator.accumulate(models_to_accumulate):
with accelerator.accumulate(models_to_accumulate), parallel_context:
prompts = batch["prompts"]
# encode batch prompts when custom prompts are provided for each image -