From 31336dae3b95a60069fc08ed4024b14f9b71fc11 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 24 Jan 2023 12:02:41 +0100 Subject: [PATCH] Fix resume epoch for all training scripts except textual_inversion (#2079) --- examples/dreambooth/train_dreambooth.py | 21 ++++++++++++------ examples/dreambooth/train_dreambooth_lora.py | 21 ++++++++++++------ .../train_dreambooth_inpaint.py | 21 ++++++++++++------ .../train_multi_subject_dreambooth.py | 21 ++++++++++++------ examples/text_to_image/train_text_to_image.py | 20 ++++++++++++----- .../text_to_image/train_text_to_image_lora.py | 21 ++++++++++++------ .../train_unconditional.py | 21 ++++++++++++------ .../train_unconditional_ort.py | 22 +++++++++++++------ 8 files changed, 113 insertions(+), 55 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 0ae8f992fe..b344cc98fd 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -757,14 +757,21 @@ def main(args): dirs = os.listdir(args.output_dir) dirs = [d for d in dirs if d.startswith("checkpoint")] dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) - path = dirs[-1] - accelerator.print(f"Resuming from checkpoint {path}") - accelerator.load_state(os.path.join(args.output_dir, path)) - global_step = int(path.split("-")[1]) + path = dirs[-1] if len(dirs) > 0 else None - resume_global_step = global_step * args.gradient_accumulation_steps - first_epoch = resume_global_step // num_update_steps_per_epoch - resume_step = resume_global_step % num_update_steps_per_epoch + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + resume_global_step = global_step * args.gradient_accumulation_steps + first_epoch = global_step // num_update_steps_per_epoch + resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) # Only show the progress bar once on each machine. progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index febff04d1b..fb9f2c832e 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -814,14 +814,21 @@ def main(args): dirs = os.listdir(args.output_dir) dirs = [d for d in dirs if d.startswith("checkpoint")] dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) - path = dirs[-1] - accelerator.print(f"Resuming from checkpoint {path}") - accelerator.load_state(os.path.join(args.output_dir, path)) - global_step = int(path.split("-")[1]) + path = dirs[-1] if len(dirs) > 0 else None - resume_global_step = global_step * args.gradient_accumulation_steps - first_epoch = resume_global_step // num_update_steps_per_epoch - resume_step = resume_global_step % num_update_steps_per_epoch + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + resume_global_step = global_step * args.gradient_accumulation_steps + first_epoch = global_step // num_update_steps_per_epoch + resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) # Only show the progress bar once on each machine. progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) diff --git a/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint.py b/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint.py index 42ae527edc..fe3c5c793e 100644 --- a/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint.py +++ b/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint.py @@ -660,14 +660,21 @@ def main(): dirs = os.listdir(args.output_dir) dirs = [d for d in dirs if d.startswith("checkpoint")] dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) - path = dirs[-1] - accelerator.print(f"Resuming from checkpoint {path}") - accelerator.load_state(os.path.join(args.output_dir, path)) - global_step = int(path.split("-")[1]) + path = dirs[-1] if len(dirs) > 0 else None - resume_global_step = global_step * args.gradient_accumulation_steps - first_epoch = resume_global_step // num_update_steps_per_epoch - resume_step = resume_global_step % num_update_steps_per_epoch + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + resume_global_step = global_step * args.gradient_accumulation_steps + first_epoch = global_step // num_update_steps_per_epoch + resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) # Only show the progress bar once on each machine. progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) diff --git a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py index 32a14e32ed..fada0b3946 100644 --- a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py +++ b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py @@ -748,14 +748,21 @@ def main(args): dirs = os.listdir(args.output_dir) dirs = [d for d in dirs if d.startswith("checkpoint")] dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) - path = dirs[-1] - accelerator.print(f"Resuming from checkpoint {path}") - accelerator.load_state(os.path.join(args.output_dir, path)) - global_step = int(path.split("-")[1]) + path = dirs[-1] if len(dirs) > 0 else None - resume_global_step = global_step * args.gradient_accumulation_steps - first_epoch = resume_global_step // num_update_steps_per_epoch - resume_step = resume_global_step % num_update_steps_per_epoch + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + resume_global_step = global_step * args.gradient_accumulation_steps + first_epoch = global_step // num_update_steps_per_epoch + resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) # Only show the progress bar once on each machine. progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 2992317ff4..d771aebc2c 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -599,13 +599,21 @@ def main(): dirs = os.listdir(args.output_dir) dirs = [d for d in dirs if d.startswith("checkpoint")] dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) - path = dirs[-1] - accelerator.print(f"Resuming from checkpoint {path}") - accelerator.load_state(os.path.join(args.output_dir, path)) - global_step = int(path.split("-")[1]) + path = dirs[-1] if len(dirs) > 0 else None - first_epoch = global_step // num_update_steps_per_epoch - resume_step = global_step % num_update_steps_per_epoch + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + resume_global_step = global_step * args.gradient_accumulation_steps + first_epoch = global_step // num_update_steps_per_epoch + resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) # Only show the progress bar once on each machine. progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 0658c45a52..9b5104de50 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -651,14 +651,21 @@ def main(): dirs = os.listdir(args.output_dir) dirs = [d for d in dirs if d.startswith("checkpoint")] dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) - path = dirs[-1] - accelerator.print(f"Resuming from checkpoint {path}") - accelerator.load_state(os.path.join(args.output_dir, path)) - global_step = int(path.split("-")[1]) + path = dirs[-1] if len(dirs) > 0 else None - resume_global_step = global_step * args.gradient_accumulation_steps - first_epoch = resume_global_step // num_update_steps_per_epoch - resume_step = resume_global_step % num_update_steps_per_epoch + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + resume_global_step = global_step * args.gradient_accumulation_steps + first_epoch = global_step // num_update_steps_per_epoch + resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) # Only show the progress bar once on each machine. progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index a5b026a910..2e2b40a544 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -439,14 +439,21 @@ def main(args): dirs = os.listdir(args.output_dir) dirs = [d for d in dirs if d.startswith("checkpoint")] dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) - path = dirs[-1] - accelerator.print(f"Resuming from checkpoint {path}") - accelerator.load_state(os.path.join(args.output_dir, path)) - global_step = int(path.split("-")[1]) + path = dirs[-1] if len(dirs) > 0 else None - resume_global_step = global_step * args.gradient_accumulation_steps - first_epoch = resume_global_step // num_update_steps_per_epoch - resume_step = resume_global_step % num_update_steps_per_epoch + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + resume_global_step = global_step * args.gradient_accumulation_steps + first_epoch = global_step // num_update_steps_per_epoch + resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) # Train! for epoch in range(first_epoch, args.num_epochs): diff --git a/examples/unconditional_image_generation/train_unconditional_ort.py b/examples/unconditional_image_generation/train_unconditional_ort.py index f91179c643..8c5622c2e1 100644 --- a/examples/unconditional_image_generation/train_unconditional_ort.py +++ b/examples/unconditional_image_generation/train_unconditional_ort.py @@ -396,13 +396,21 @@ def main(args): dirs = os.listdir(args.output_dir) dirs = [d for d in dirs if d.startswith("checkpoint")] dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) - path = dirs[-1] - accelerator.print(f"Resuming from checkpoint {path}") - accelerator.load_state(os.path.join(args.output_dir, path)) - global_step = int(path.split("-")[1]) - resume_global_step = global_step * args.gradient_accumulation_steps - first_epoch = resume_global_step // num_update_steps_per_epoch - resume_step = resume_global_step % num_update_steps_per_epoch + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + resume_global_step = global_step * args.gradient_accumulation_steps + first_epoch = global_step // num_update_steps_per_epoch + resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) for epoch in range(first_epoch, args.num_epochs): model.train()