diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index c9c73efe9b..841849dcf3 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -685,9 +685,8 @@ def main(): accelerator.load_state(os.path.join(args.output_dir, path)) global_step = int(path.split("-")[1]) - resume_global_step = global_step * args.gradient_accumulation_steps - first_epoch = resume_global_step // num_update_steps_per_epoch - resume_step = resume_global_step % num_update_steps_per_epoch + first_epoch = global_step // num_update_steps_per_epoch + resume_step = global_step % num_update_steps_per_epoch # Only show the progress bar once on each machine. progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)