mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Log global_step instead of epoch to tensorboard (#4493)
Co-authored-by: mrlzla <noname@noname.com>
This commit is contained in:
committed by
GitHub
parent
9c29bc2df8
commit
2e69cf16fe
@@ -107,7 +107,16 @@ DreamBooth for the text encoder was enabled: {train_text_encoder}.
|
||||
|
||||
|
||||
def log_validation(
|
||||
text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch, prompt_embeds, negative_prompt_embeds
|
||||
text_encoder,
|
||||
tokenizer,
|
||||
unet,
|
||||
vae,
|
||||
args,
|
||||
accelerator,
|
||||
weight_dtype,
|
||||
global_step,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
):
|
||||
logger.info(
|
||||
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
|
||||
@@ -173,7 +182,7 @@ def log_validation(
|
||||
for tracker in accelerator.trackers:
|
||||
if tracker.name == "tensorboard":
|
||||
np_images = np.stack([np.asarray(img) for img in images])
|
||||
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
|
||||
tracker.writer.add_images("validation", np_images, global_step, dataformats="NHWC")
|
||||
if tracker.name == "wandb":
|
||||
tracker.log(
|
||||
{
|
||||
@@ -1308,7 +1317,7 @@ def main(args):
|
||||
args,
|
||||
accelerator,
|
||||
weight_dtype,
|
||||
epoch,
|
||||
global_step,
|
||||
validation_prompt_encoder_hidden_states,
|
||||
validation_prompt_negative_prompt_embeds,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user