mirror of https://github.com/huggingface/diffusers.git
try to make dreambooth script work; accelerator backward not playing well
@@ -67,6 +67,8 @@ from diffusers import (
     FlowMatchEulerDiscreteScheduler,
     FluxPipeline,
     FluxTransformer2DModel,
+    ParallelConfig,
+    enable_parallelism,
 )
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import (
@@ -805,6 +807,8 @@ def parse_args(input_args=None):
         ],
         help="The image interpolation method to use for resizing images.",
     )
+    parser.add_argument("--context_parallel_degree", type=int, default=1, help="The degree for context parallelism.")
+    parser.add_argument("--context_parallel_type", type=str, default="ulysses", help="The type of context parallelism to use. Choose between 'ulysses' and 'ring'.")

     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -1347,15 +1351,28 @@ def main(args):

     logging_dir = Path(args.output_dir, args.logging_dir)

+    cp_degree = args.context_parallel_degree
     accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
-    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+    if cp_degree > 1:
+        kwargs = []
+    else:
+        kwargs = [DistributedDataParallelKwargs(find_unused_parameters=True)]
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
         log_with=args.report_to,
         project_config=accelerator_project_config,
-        kwargs_handlers=[kwargs],
+        kwargs_handlers=kwargs,
     )
+    if cp_degree > 1 and not torch.distributed.is_initialized():
+        if not torch.cuda.is_available():
+            raise ValueError("Context parallelism is only tested on CUDA devices.")
+        if os.environ.get("WORLD_SIZE", None) is None:
+            raise ValueError("Try launching the program with `torchrun --nproc_per_node <NUM_GPUS>` instead of `accelerate launch <NUM_GPUS>`.")
+        torch.distributed.init_process_group("nccl")
+        rank = torch.distributed.get_rank()
+        rank = accelerator.process_index
+        torch.cuda.set_device(torch.device("cuda", rank % torch.cuda.device_count()))

     # Disable AMP for MPS.
     if torch.backends.mps.is_available():
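The hunk above only registers the find_unused_parameters DDP handler when context parallelism is off; with cp_degree > 1 it passes no kwargs handlers and brings up torch.distributed itself, which presumes a torchrun launch (the error message spells this out; the diff then takes the rank from accelerator.process_index). A standalone sketch of that init path, assuming a multi-GPU CUDA host; the script name below is a placeholder:

# Launched as: torchrun --nproc_per_node <NUM_GPUS> train_dreambooth.py --context_parallel_degree <NUM_GPUS>
# (script name is a placeholder; torchrun sets WORLD_SIZE, RANK and LOCAL_RANK for each process).
import os
import torch

if not torch.distributed.is_initialized():
    if not torch.cuda.is_available():
        raise ValueError("Context parallelism is only tested on CUDA devices.")
    if os.environ.get("WORLD_SIZE", None) is None:
        raise ValueError("Launch with `torchrun --nproc_per_node <NUM_GPUS>`.")
    torch.distributed.init_process_group("nccl")
    rank = torch.distributed.get_rank()
    torch.cuda.set_device(torch.device("cuda", rank % torch.cuda.device_count()))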
@@ -1977,6 +1994,14 @@ def main(args):
         power=args.lr_power,
     )

+    # Enable context parallelism
+    if cp_degree > 1:
+        ring_degree = cp_degree if args.context_parallel_type == "ring" else None
+        ulysses_degree = cp_degree if args.context_parallel_type == "ulysses" else None
+        transformer.parallelize(config=ParallelConfig(ring_degree=ring_degree, ulysses_degree=ulysses_degree))
+        transformer.set_attention_backend("_native_cudnn")
+    parallel_context = enable_parallelism(transformer) if cp_degree > 1 else nullcontext()
+
     # Prepare everything with our `accelerator`.
     if not freeze_text_encoder:
         if args.enable_t5_ti:
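Together with the new imports from the first hunk, the block above is the whole context-parallel setup: pick ring or ulysses from the CLI flag, parallelize the transformer, and pin the attention backend. A minimal standalone sketch of that flow, assuming a torchrun launch on 2 GPUs; the checkpoint id is a placeholder for whatever the script actually loads:

# Minimal sketch of the context-parallel setup wired into the trainer above.
from contextlib import nullcontext

import torch
from diffusers import FluxTransformer2DModel, ParallelConfig, enable_parallelism

torch.distributed.init_process_group("nccl")
rank = torch.distributed.get_rank()
torch.cuda.set_device(torch.device("cuda", rank % torch.cuda.device_count()))

cp_degree = 2  # assumed to match the number of launched processes
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
).to("cuda")

if cp_degree > 1:
    # Ulysses-style context parallelism; ring attention would set ring_degree instead.
    transformer.parallelize(config=ParallelConfig(ring_degree=None, ulysses_degree=cp_degree))
    transformer.set_attention_backend("_native_cudnn")

# The diff uses enable_parallelism(...) as a context manager around each training step.
parallel_context = enable_parallelism(transformer) if cp_degree > 1 else nullcontext()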
@@ -2131,7 +2156,7 @@ def main(args):
                 logger.info(f"PIVOT TRANSFORMER {epoch}")
                 optimizer.param_groups[0]["lr"] = 0.0

-            with accelerator.accumulate(models_to_accumulate):
+            with accelerator.accumulate(models_to_accumulate), parallel_context:
                 prompts = batch["prompts"]

                 # encode batch prompts when custom prompts are provided for each image -
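When --context_parallel_degree is 1, parallel_context is a plain nullcontext(), so the combined with-statement above degenerates to the original accumulate-only behaviour; the commit title notes that accelerator's backward is "not playing well", presumably on the context-parallel path. A tiny self-contained toy of the stacked-context pattern on the degenerate path:

# Toy showing that stacking a nullcontext leaves the training step unchanged (cp_degree == 1 path).
from contextlib import nullcontext

import torch

parallel_context = nullcontext()  # what the diff assigns when cp_degree == 1
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Stands in for: with accelerator.accumulate(models_to_accumulate), parallel_context:
with parallel_context:
    loss = model(torch.randn(2, 4)).pow(2).mean()
    loss.backward()  # accelerator.backward(loss) in the actual script
    optimizer.step()
    optimizer.zero_grad()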