From 2e69cf16fea891eadd128d912c3c8a4254a007af Mon Sep 17 00:00:00 2001 From: Vladislav Artemyev Date: Mon, 7 Aug 2023 04:19:39 +0200 Subject: [PATCH] Log global_step instead of epoch to tensorboard (#4493) Co-authored-by: mrlzla --- examples/dreambooth/train_dreambooth.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 28559dd172..eed4df368b 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -107,7 +107,16 @@ DreamBooth for the text encoder was enabled: {train_text_encoder}. def log_validation( - text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch, prompt_embeds, negative_prompt_embeds + text_encoder, + tokenizer, + unet, + vae, + args, + accelerator, + weight_dtype, + global_step, + prompt_embeds, + negative_prompt_embeds, ): logger.info( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" @@ -173,7 +182,7 @@ def log_validation( for tracker in accelerator.trackers: if tracker.name == "tensorboard": np_images = np.stack([np.asarray(img) for img in images]) - tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + tracker.writer.add_images("validation", np_images, global_step, dataformats="NHWC") if tracker.name == "wandb": tracker.log( { @@ -1308,7 +1317,7 @@ def main(args): args, accelerator, weight_dtype, - epoch, + global_step, validation_prompt_encoder_hidden_states, validation_prompt_negative_prompt_embeds, )