From f4dddaf5ee3ea784a700676abe2d48c9fc3feecf Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Tue, 24 Jan 2023 10:25:41 +0100
Subject: [PATCH] [textual_inversion] Fix resuming state when using gradient
 checkpointing (#2072)

* Fix resuming state when using gradient checkpointing.

Also, allow --resume_from_checkpoint to be used when the checkpoint does
not yet exist (a normal training run will be started).

* style
---
 .../textual_inversion/textual_inversion.py | 23 ++++++++++++-------
 1 file changed, 15 insertions(+), 8 deletions(-)

diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py
index 1e52036da9..0b3515ff8e 100644
--- a/examples/textual_inversion/textual_inversion.py
+++ b/examples/textual_inversion/textual_inversion.py
@@ -597,7 +597,7 @@ def main():
         text_encoder, optimizer, train_dataloader, lr_scheduler
     )
 
-    # For mixed precision training we cast the text_encoder and vae weights to half-precision
+    # For mixed precision training we cast the unet and vae weights to half-precision
     # as these models are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32
     if accelerator.mixed_precision == "fp16":
@@ -643,14 +643,21 @@ def main():
             dirs = os.listdir(args.output_dir)
             dirs = [d for d in dirs if d.startswith("checkpoint")]
             dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
-            path = dirs[-1]
-        accelerator.print(f"Resuming from checkpoint {path}")
-        accelerator.load_state(os.path.join(args.output_dir, path))
-        global_step = int(path.split("-")[1])
+            path = dirs[-1] if len(dirs) > 0 else None
 
-        resume_global_step = global_step * args.gradient_accumulation_steps
-        first_epoch = resume_global_step // num_update_steps_per_epoch
-        resume_step = resume_global_step % num_update_steps_per_epoch
+        if path is None:
+            accelerator.print(
+                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+            )
+            args.resume_from_checkpoint = None
+        else:
+            accelerator.print(f"Resuming from checkpoint {path}")
+            accelerator.load_state(os.path.join(args.output_dir, path))
+            global_step = int(path.split("-")[1])
+
+            resume_global_step = global_step * args.gradient_accumulation_steps
+            first_epoch = global_step // num_update_steps_per_epoch
+            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
 
     # Only show the progress bar once on each machine.
     progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
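
For reference, a minimal sketch of the resume arithmetic introduced by this patch, using hypothetical values (4 gradient-accumulation steps, 100 optimizer updates per epoch, a "checkpoint-250" directory) chosen purely to illustrate the formulas:

    # Hypothetical values, for illustration only (not taken from the script's defaults).
    gradient_accumulation_steps = 4    # args.gradient_accumulation_steps
    num_update_steps_per_epoch = 100   # optimizer updates per epoch
    global_step = 250                  # parsed from the "checkpoint-250" directory name

    # One optimizer update consumes `gradient_accumulation_steps` dataloader batches,
    # so convert the checkpointed update count back into batches for the skip logic.
    resume_global_step = global_step * gradient_accumulation_steps  # 1000 batches

    # Epochs are counted in optimizer updates, so divide the update count directly.
    first_epoch = global_step // num_update_steps_per_epoch  # 2

    # Offset within the resumed epoch, measured in batches; the inner training loop
    # skips this many batches before it starts stepping again.
    resume_step = resume_global_step % (num_update_steps_per_epoch * gradient_accumulation_steps)  # 200

With these numbers the run resumes at epoch 2 and skips the first 200 batches (50 optimizer updates), which lines up with the 250 updates recorded in the checkpoint name.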