1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Examples] Update train_unconditional.py to include logging argument for Wandb (#1719)

Update train_unconditional.py

Add logger flag to choose between tensorboard and wandb
This commit is contained in:
Anish Shah
2022-12-19 10:57:03 -05:00
committed by GitHub
parent ce1c27adc8
commit 9f657f106d

View File

@@ -173,6 +173,16 @@ def parse_args():
parser.add_argument(
"--hub_private_repo", action="store_true", help="Whether or not to create a private repository."
)
parser.add_argument(
"--logger",
type=str,
default="tensorboard",
choices=["tensorboard", "wandb"],
help=(
"Whether to use [tensorboard](https://www.tensorflow.org/tensorboard) or [wandb](https://www.wandb.ai)"
" for experiment tracking and logging of model metrics and model checkpoints"
),
)
parser.add_argument(
"--logging_dir",
type=str,
@@ -248,7 +258,7 @@ def main(args):
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with="tensorboard",
log_with=args.logger,
logging_dir=logging_dir,
)
@@ -477,9 +487,11 @@ def main(args):
# denormalize the images and save to tensorboard
images_processed = (images * 255).round().astype("uint8")
accelerator.trackers[0].writer.add_images(
"test_samples", images_processed.transpose(0, 3, 1, 2), epoch
)
if args.logger == "tensorboard":
accelerator.get_tracker("tensorboard").add_images(
"test_samples", images_processed.transpose(0, 3, 1, 2), epoch
)
if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
# save the model