From badddee0ef8282e07232dcd3990070e2d524f57c Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Thu, 15 Dec 2022 19:49:40 +0100 Subject: [PATCH] Add state checkpointing to other training scripts (#1687) * Add state checkpointing to other training scripts * Fix first_epoch * Apply suggestions from code review Co-authored-by: Patrick von Platen * Update Dreambooth checkpoint help message. * Dreambooth docs: checkpoints, inference from a checkpoint. * make style Co-authored-by: Patrick von Platen --- docs/source/training/dreambooth.mdx | 67 +++++++++++++++---- examples/dreambooth/train_dreambooth.py | 3 +- .../train_dreambooth_inpaint.py | 59 ++++++++++++++-- examples/text_to_image/train_text_to_image.py | 55 ++++++++++++++- .../textual_inversion/textual_inversion.py | 58 ++++++++++++++-- .../train_unconditional.py | 54 ++++++++++++++- 6 files changed, 269 insertions(+), 27 deletions(-) diff --git a/docs/source/training/dreambooth.mdx b/docs/source/training/dreambooth.mdx index 238dcb24cf..6cea3fb773 100644 --- a/docs/source/training/dreambooth.mdx +++ b/docs/source/training/dreambooth.mdx @@ -21,8 +21,6 @@ The [Dreambooth training script](https://github.com/huggingface/diffusers/tree/m - - Dreambooth fine-tuning is very sensitive to hyperparameters and easy to overfit. We recommend you take a look at our [in-depth analysis](https://huggingface.co/blog/dreambooth) with recommended settings for different subjects, and go from there. @@ -44,17 +42,9 @@ Then initialize and configure a [🤗 Accelerate](https://github.com/huggingface accelerate config ``` -You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree. +In this example we'll use model version `v1-4`, so please visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4) and carefully read the license before proceeding. -You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens). - -Run the following command to authenticate your token - -```bash -huggingface-cli login -``` - -If you have already cloned the repo, then you won't need to go through these steps. Instead, you can pass the path to your local checkout to the training script and it will be loaded from there. +The command below will download and cache the model weights from the Hub because we use the model's Hub id `CompVis/stable-diffusion-v1-4`. You may also clone the repo locally and use the local path in your system where the checkout was saved. ### Dog toy example @@ -111,6 +101,59 @@ accelerate launch train_dreambooth.py \ --max_train_steps=800 ``` +### Saving checkpoints while training + +It's easy to overfit while training with Dreambooth, so sometimes it's useful to save regular checkpoints during the process. One of the intermediate checkpoints might work better than the final model! To use this feature you need to pass the following argument to the training script: + +```bash + --checkpointing_steps=500 +``` + +This will save the full training state in subfolders of your `output_dir`. 
Subfolder names begin with the prefix `checkpoint-`, followed by the number of steps performed so far; for example: `checkpoint-1500` would be a checkpoint saved after 1500 training steps.
+
+#### Resuming training from a saved checkpoint
+
+If you want to resume training from any of the saved checkpoints, you can pass the argument `--resume_from_checkpoint` and then indicate the name of the checkpoint you want to use. You can also use the special string `"latest"` to resume from the last checkpoint saved (i.e., the one with the largest number of steps). For example, the following would resume training from the checkpoint saved after 1500 steps:
+
+```bash
+  --resume_from_checkpoint="checkpoint-1500"
+```
+
+This would be a good opportunity to tweak some of your hyperparameters if you wish.
+
+#### Performing inference using a saved checkpoint
+
+Saved checkpoints are stored in a format suitable for resuming training. They include not only the model weights, but also the state of the optimizer, data loaders, and learning rate scheduler.
+
+You can use a checkpoint for inference, but first you need to convert it to an inference pipeline. This is how you could do it:
+
+```python
+from accelerate import Accelerator
+from diffusers import DiffusionPipeline
+
+# Load the pipeline with the same arguments (model, revision) that were used for training
+model_id = "CompVis/stable-diffusion-v1-4"
+pipeline = DiffusionPipeline.from_pretrained(model_id)
+
+accelerator = Accelerator()
+
+# Use text_encoder if `--train_text_encoder` was used for the initial training
+unet, text_encoder = accelerator.prepare(pipeline.unet, pipeline.text_encoder)
+
+# Restore state from a checkpoint path. You have to use the absolute path here.
+accelerator.load_state("/sddata/dreambooth/daruma-v2-1/checkpoint-100")
+
+# Rebuild the pipeline with the unwrapped models (assignment to .unet and .text_encoder should work too)
+pipeline = DiffusionPipeline.from_pretrained(
+    model_id,
+    unet=accelerator.unwrap_model(unet),
+    text_encoder=accelerator.unwrap_model(text_encoder),
+)
+
+# Perform inference, or save, or push to the hub
+pipeline.save_pretrained("dreambooth-pipeline")
+```
+
 ### Training on a 16GB GPU
 
 With the help of gradient checkpointing and the 8-bit optimizer from [bitsandbytes](https://github.com/TimDettmers/bitsandbytes), it's possible to train dreambooth on a 16GB GPU.
diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index adc1359a5e..de735141d5 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -155,7 +155,8 @@ def parse_args(input_args=None):
         type=int,
         default=500,
         help=(
-            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+            "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+            " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
             " training using `--resume_from_checkpoint`."
), ) diff --git a/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint.py b/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint.py index b7f6900926..e1280ede45 100644 --- a/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint.py +++ b/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint.py @@ -242,6 +242,25 @@ def parse_args(): ), ) parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint and are suitable for resuming training" + " using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) args = parser.parse_args() env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) @@ -591,6 +610,7 @@ def main(): unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( unet, optimizer, train_dataloader, lr_scheduler ) + accelerator.register_for_checkpointing(lr_scheduler) weight_dtype = torch.float32 if args.mixed_precision == "fp16": @@ -628,14 +648,39 @@ def main(): logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {args.max_train_steps}") - # Only show the progress bar once on each machine. - progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) - progress_bar.set_description("Steps") global_step = 0 + first_epoch = 0 - for epoch in range(args.num_train_epochs): + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + resume_global_step = global_step * args.gradient_accumulation_steps + first_epoch = resume_global_step // num_update_steps_per_epoch + resume_step = resume_global_step % num_update_steps_per_epoch + + # Only show the progress bar once on each machine. 
+    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
+    progress_bar.set_description("Steps")
+
+    for epoch in range(first_epoch, args.num_train_epochs):
         unet.train()
         for step, batch in enumerate(train_dataloader):
+            # Skip steps until we reach the resumed step
+            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
+                if step % args.gradient_accumulation_steps == 0:
+                    progress_bar.update(1)
+                continue
+
             with accelerator.accumulate(unet):
                 # Convert images to latent space
@@ -719,6 +764,12 @@ def main():
                     progress_bar.update(1)
                     global_step += 1
 
+                if global_step % args.checkpointing_steps == 0:
+                    if accelerator.is_main_process:
+                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+                        accelerator.save_state(save_path)
+                        logger.info(f"Saved state to {save_path}")
+
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
             accelerator.log(logs, step=global_step)
diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 08386728bc..6c45ee0b1b 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -216,6 +216,24 @@ def parse_args():
         ),
     )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+    parser.add_argument(
+        "--checkpointing_steps",
+        type=int,
+        default=500,
+        help=(
+            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+            " training using `--resume_from_checkpoint`."
+        ),
+    )
+    parser.add_argument(
+        "--resume_from_checkpoint",
+        type=str,
+        default=None,
+        help=(
+            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+        ),
+    )
 
     args = parser.parse_args()
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -528,6 +546,7 @@ def main():
     unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
         unet, optimizer, train_dataloader, lr_scheduler
     )
+    accelerator.register_for_checkpointing(lr_scheduler)
 
     weight_dtype = torch.float32
     if accelerator.mixed_precision == "fp16":
@@ -567,16 +586,40 @@ def main():
     logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
     logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
     logger.info(f"  Total optimization steps = {args.max_train_steps}")
+    global_step = 0
+    first_epoch = 0
+
+    if args.resume_from_checkpoint:
+        if args.resume_from_checkpoint != "latest":
+            path = os.path.basename(args.resume_from_checkpoint)
+        else:
+            # Get the most recent checkpoint
+            dirs = os.listdir(args.output_dir)
+            dirs = [d for d in dirs if d.startswith("checkpoint")]
+            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+            path = dirs[-1]
+        accelerator.print(f"Resuming from checkpoint {path}")
+        accelerator.load_state(os.path.join(args.output_dir, path))
+        global_step = int(path.split("-")[1])
+
+        resume_global_step = global_step * args.gradient_accumulation_steps
+        first_epoch = resume_global_step // num_update_steps_per_epoch
+        resume_step = resume_global_step % num_update_steps_per_epoch
 
     # Only show the progress bar once on each machine.
- progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) progress_bar.set_description("Steps") - global_step = 0 - for epoch in range(args.num_train_epochs): + for epoch in range(first_epoch, args.num_train_epochs): unet.train() train_loss = 0.0 for step, batch in enumerate(train_dataloader): + # Skip steps until we reach the resumed step + if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: + if step % args.gradient_accumulation_steps == 0: + progress_bar.update(1) + continue + with accelerator.accumulate(unet): # Convert images to latent space latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample() @@ -629,6 +672,12 @@ def main(): accelerator.log({"train_loss": train_loss}, step=global_step) train_loss = 0.0 + if global_step % args.checkpointing_steps == 0: + if accelerator.is_main_process: + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py index 380a8b4738..a2e023a5ce 100644 --- a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -205,6 +205,24 @@ def parse_args(): ), ) parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) args = parser.parse_args() env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) @@ -512,6 +530,7 @@ def main(): text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( text_encoder, optimizer, train_dataloader, lr_scheduler ) + accelerator.register_for_checkpointing(lr_scheduler) # Move vae and unet to device vae.to(accelerator.device) @@ -543,17 +562,42 @@ def main(): logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {args.max_train_steps}") - # Only show the progress bar once on each machine. 
- progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) - progress_bar.set_description("Steps") global_step = 0 + first_epoch = 0 + + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + resume_global_step = global_step * args.gradient_accumulation_steps + first_epoch = resume_global_step // num_update_steps_per_epoch + resume_step = resume_global_step % num_update_steps_per_epoch + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar.set_description("Steps") # keep original embeddings as reference orig_embeds_params = text_encoder.get_input_embeddings().weight.data.clone() - for epoch in range(args.num_train_epochs): + for epoch in range(first_epoch, args.num_train_epochs): text_encoder.train() for step, batch in enumerate(train_dataloader): + # Skip steps until we reach the resumed step + if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: + if step % args.gradient_accumulation_steps == 0: + progress_bar.update(1) + continue + with accelerator.accumulate(text_encoder): # Convert images to latent space latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() @@ -605,6 +649,12 @@ def main(): save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin") save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path) + if global_step % args.checkpointing_steps == 0: + if accelerator.is_main_process: + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) accelerator.log(logs, step=global_step) diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index 8f2fcce268..53b0d53253 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -194,7 +194,6 @@ def parse_args(): "and an Nvidia Ampere GPU." ), ) - parser.add_argument( "--prediction_type", type=str, @@ -202,9 +201,26 @@ def parse_args(): choices=["epsilon", "sample"], help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.", ) - parser.add_argument("--ddpm_num_steps", type=int, default=1000) parser.add_argument("--ddpm_beta_schedule", type=str, default="linear") + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. 
Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) args = parser.parse_args() env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) @@ -319,6 +335,7 @@ def main(args): model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( model, optimizer, train_dataloader, lr_scheduler ) + accelerator.register_for_checkpointing(lr_scheduler) num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) @@ -351,11 +368,36 @@ def main(args): accelerator.init_trackers(run) global_step = 0 - for epoch in range(args.num_epochs): + first_epoch = 0 + + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + resume_global_step = global_step * args.gradient_accumulation_steps + first_epoch = resume_global_step // num_update_steps_per_epoch + resume_step = resume_global_step % num_update_steps_per_epoch + + for epoch in range(first_epoch, args.num_epochs): model.train() progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process) progress_bar.set_description(f"Epoch {epoch}") for step, batch in enumerate(train_dataloader): + # Skip steps until we reach the resumed step + if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: + if step % args.gradient_accumulation_steps == 0: + progress_bar.update(1) + continue + clean_images = batch["input"] # Sample noise that we'll add to the images noise = torch.randn(clean_images.shape).to(clean_images.device) @@ -402,6 +444,12 @@ def main(args): progress_bar.update(1) global_step += 1 + if global_step % args.checkpointing_steps == 0: + if accelerator.is_main_process: + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step} if args.use_ema: logs["ema_decay"] = ema_model.decay