From e01d6cf295043d5e98612d836ae2a281adcdf242 Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Tue, 13 Dec 2022 15:16:44 +0100
Subject: [PATCH] Dreambooth: save / restore training state (#1668)

* Dreambooth: save / restore training state.

* make style

* Rename vars for clarity.

Co-authored-by: Patrick von Platen

* Remove unused import

Co-authored-by: Patrick von Platen
---
 examples/dreambooth/train_dreambooth.py | 74 +++++++++++++++++--------
 1 file changed, 51 insertions(+), 23 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index 86a60728d5..9f46d8c775 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -1,6 +1,5 @@
 import argparse
 import hashlib
-import inspect
 import itertools
 import math
 import os
@@ -150,7 +149,24 @@ def parse_args(input_args=None):
         default=None,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
-    parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
+    parser.add_argument(
+        "--checkpointing_steps",
+        type=int,
+        default=500,
+        help=(
+            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+            " training using `--resume_from_checkpoint`."
+        ),
+    )
+    parser.add_argument(
+        "--resume_from_checkpoint",
+        type=str,
+        default=None,
+        help=(
+            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+        ),
+    )
     parser.add_argument(
         "--gradient_accumulation_steps",
         type=int,
@@ -579,6 +595,7 @@ def main(args):
         unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
             unet, optimizer, train_dataloader, lr_scheduler
         )
+    accelerator.register_for_checkpointing(lr_scheduler)
 
     weight_dtype = torch.float32
     if accelerator.mixed_precision == "fp16":
@@ -616,16 +633,41 @@ def main(args):
     logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
     logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
     logger.info(f"  Total optimization steps = {args.max_train_steps}")
-    # Only show the progress bar once on each machine.
-    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
-    progress_bar.set_description("Steps")
     global_step = 0
+    first_epoch = 0
 
-    for epoch in range(args.num_train_epochs):
+    if args.resume_from_checkpoint:
+        if args.resume_from_checkpoint != "latest":
+            path = os.path.basename(args.resume_from_checkpoint)
+        else:
+            # Get the most recent checkpoint
+            dirs = os.listdir(args.output_dir)
+            dirs = [d for d in dirs if d.startswith("checkpoint")]
+            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+            path = dirs[-1]
+        accelerator.print(f"Resuming from checkpoint {path}")
+        accelerator.load_state(os.path.join(args.output_dir, path))
+        global_step = int(path.split("-")[1])
+
+        resume_global_step = global_step * args.gradient_accumulation_steps
+        first_epoch = resume_global_step // num_update_steps_per_epoch
+        resume_step = resume_global_step % num_update_steps_per_epoch
+
+    # Only show the progress bar once on each machine.
+    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
+    progress_bar.set_description("Steps")
+
+    for epoch in range(first_epoch, args.num_train_epochs):
         unet.train()
         if args.train_text_encoder:
             text_encoder.train()
         for step, batch in enumerate(train_dataloader):
+            # Skip steps until we reach the resumed step
+            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
+                if step % args.gradient_accumulation_steps == 0:
+                    progress_bar.update(1)
+                continue
+
             with accelerator.accumulate(unet):
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
@@ -689,25 +731,11 @@ def main(args):
                 progress_bar.update(1)
                 global_step += 1
 
-                if global_step % args.save_steps == 0:
+                if global_step % args.checkpointing_steps == 0:
                     if accelerator.is_main_process:
-                        # When 'keep_fp32_wrapper' is `False` (the default), then the models are
-                        # unwrapped and the mixed precision hooks are removed, so training crashes
-                        # when the unwrapped models are used for further training.
-                        # This is only supported in newer versions of `accelerate`.
-                        # TODO(Pedro, Suraj): Remove `accepts_keep_fp32_wrapper` when forcing newer accelerate versions
-                        accepts_keep_fp32_wrapper = "keep_fp32_wrapper" in set(
-                            inspect.signature(accelerator.unwrap_model).parameters.keys()
-                        )
-                        extra_args = {"keep_fp32_wrapper": True} if accepts_keep_fp32_wrapper else {}
-                        pipeline = DiffusionPipeline.from_pretrained(
-                            args.pretrained_model_name_or_path,
-                            unet=accelerator.unwrap_model(unet, **extra_args),
-                            text_encoder=accelerator.unwrap_model(text_encoder, **extra_args),
-                            revision=args.revision,
-                        )
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
-                        pipeline.save_pretrained(save_path)
+                        accelerator.save_state(save_path)
+                        logger.info(f"Saved state to {save_path}")
 
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
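
Note (not part of the patch): for readers who want the new save/resume flow in one place, the following is a minimal sketch of the pattern the patch adopts, taken out of the context of the script. The toy model, the placeholder values standing in for --output_dir, --checkpointing_steps, --resume_from_checkpoint and --max_train_steps, and the empty training loop are all illustrative assumptions; only the accelerate calls (prepare, register_for_checkpointing, save_state, load_state) mirror what the patch uses.

import os

import torch
from accelerate import Accelerator
from torch.optim.lr_scheduler import LambdaLR

accelerator = Accelerator()

# Stand-ins for the real UNet / optimizer prepared in train_dreambooth.py.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
lr_scheduler = LambdaLR(optimizer, lambda _: 1.0)

model, optimizer = accelerator.prepare(model, optimizer)
# The scheduler is not passed to prepare() in this sketch, so register it
# explicitly to have save_state()/load_state() include it in the checkpoint.
accelerator.register_for_checkpointing(lr_scheduler)

output_dir = "output"              # placeholder for --output_dir
checkpointing_steps = 500          # placeholder for --checkpointing_steps
resume_from_checkpoint = "latest"  # placeholder for --resume_from_checkpoint
max_train_steps = 1000             # placeholder for --max_train_steps

os.makedirs(output_dir, exist_ok=True)
global_step = 0

if resume_from_checkpoint:
    if resume_from_checkpoint != "latest":
        path = os.path.basename(resume_from_checkpoint)
    else:
        # Pick the newest "checkpoint-<step>" directory by its numeric suffix.
        dirs = [d for d in os.listdir(output_dir) if d.startswith("checkpoint")]
        dirs = sorted(dirs, key=lambda d: int(d.split("-")[1]))
        path = dirs[-1] if dirs else None
    if path is not None:
        accelerator.load_state(os.path.join(output_dir, path))
        global_step = int(path.split("-")[1])

for _ in range(global_step, max_train_steps):
    # ... forward pass, backward pass and optimizer step would go here ...
    global_step += 1
    if global_step % checkpointing_steps == 0 and accelerator.is_main_process:
        save_path = os.path.join(output_dir, f"checkpoint-{global_step}")
        accelerator.save_state(save_path)

Run once, this saves checkpoint-500 and checkpoint-1000 under output/; run again with "latest", it reloads checkpoint-1000 and the loop has nothing left to do, which is the behavior the patch adds to the training script.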
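
A small point that is easy to miss in the "latest" branch: checkpoint directories are sorted by their numeric step suffix rather than lexicographically, because a plain string sort would rank checkpoint-1000 before checkpoint-500 and dirs[-1] would then pick a stale checkpoint. A quick illustration with hypothetical directory names:

dirs = ["checkpoint-500", "checkpoint-1000", "checkpoint-1500"]  # hypothetical names

print(sorted(dirs))
# ['checkpoint-1000', 'checkpoint-1500', 'checkpoint-500']  -> lexicographic, [-1] picks checkpoint-500

print(sorted(dirs, key=lambda d: int(d.split("-")[1])))
# ['checkpoint-500', 'checkpoint-1000', 'checkpoint-1500'] -> numeric, [-1] is the newest checkpoint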