1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Log Unconditional Image Generation Samples to W&B (#2287)

* Log Unconditional Image Generation Samples to WandB

* Check for wandb installation and parity between onnxruntime script

* Log epoch to wandb

* Check for tensorboard logger early on

* style fixes

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Ben Evans
2023-02-14 22:11:12 +00:00
committed by GitHub
parent 62b3c9e06a
commit 0db19da01f
2 changed files with 33 additions and 4 deletions

View File

@@ -21,7 +21,7 @@ import diffusers
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version
from diffusers.utils import check_min_version, is_tensorboard_available, is_wandb_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -220,6 +220,7 @@ def parse_args():
help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
)
parser.add_argument("--ddpm_num_steps", type=int, default=1000)
parser.add_argument("--ddpm_num_inference_steps", type=int, default=1000)
parser.add_argument("--ddpm_beta_schedule", type=str, default="linear")
parser.add_argument(
"--checkpointing_steps",
@@ -271,6 +272,15 @@ def main(args):
logging_dir=logging_dir,
)
if args.logger == "tensorboard":
if not is_tensorboard_available():
raise ImportError("Make sure to install tensorboard if you want to use it for logging during training.")
elif args.logger == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
import wandb
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -552,7 +562,7 @@ def main(args):
generator=generator,
batch_size=args.eval_batch_size,
output_type="numpy",
num_inference_steps=args.ddpm_num_steps,
num_inference_steps=args.ddpm_num_inference_steps,
).images
# denormalize the images and save to tensorboard
@@ -562,6 +572,11 @@ def main(args):
accelerator.get_tracker("tensorboard").add_images(
"test_samples", images_processed.transpose(0, 3, 1, 2), epoch
)
elif args.logger == "wandb":
accelerator.get_tracker("wandb").log(
{"test_samples": [wandb.Image(img) for img in images_processed], "epoch": epoch},
step=global_step,
)
if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
# save the model

View File

@@ -22,7 +22,7 @@ import diffusers
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version, is_tensorboard_available
from diffusers.utils import check_min_version, is_tensorboard_available, is_wandb_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -280,6 +280,15 @@ def main(args):
logging_dir=logging_dir,
)
if args.logger == "tensorboard":
if not is_tensorboard_available():
raise ImportError("Make sure to install tensorboard if you want to use it for logging during training.")
elif args.logger == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
import wandb
# `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
@@ -604,10 +613,15 @@ def main(args):
# denormalize the images and save to tensorboard
images_processed = (images * 255).round().astype("uint8")
if args.logger == "tensorboard" and is_tensorboard_available():
if args.logger == "tensorboard":
accelerator.get_tracker("tensorboard").add_images(
"test_samples", images_processed.transpose(0, 3, 1, 2), epoch
)
elif args.logger == "wandb":
accelerator.get_tracker("wandb").log(
{"test_samples": [wandb.Image(img) for img in images_processed], "epoch": epoch},
step=global_step,
)
if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
# save the model