mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Log Unconditional Image Generation Samples to W&B (#2287)
* Log Unconditional Image Generation Samples to WandB * Check for wandb installation and parity between onnxruntime script * Log epoch to wandb * Check for tensorboard logger early on * style fixes --------- Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -21,7 +21,7 @@ import diffusers
|
||||
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import EMAModel
|
||||
from diffusers.utils import check_min_version
|
||||
from diffusers.utils import check_min_version, is_tensorboard_available, is_wandb_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
@@ -220,6 +220,7 @@ def parse_args():
|
||||
help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
|
||||
)
|
||||
parser.add_argument("--ddpm_num_steps", type=int, default=1000)
|
||||
parser.add_argument("--ddpm_num_inference_steps", type=int, default=1000)
|
||||
parser.add_argument("--ddpm_beta_schedule", type=str, default="linear")
|
||||
parser.add_argument(
|
||||
"--checkpointing_steps",
|
||||
@@ -271,6 +272,15 @@ def main(args):
|
||||
logging_dir=logging_dir,
|
||||
)
|
||||
|
||||
if args.logger == "tensorboard":
|
||||
if not is_tensorboard_available():
|
||||
raise ImportError("Make sure to install tensorboard if you want to use it for logging during training.")
|
||||
|
||||
elif args.logger == "wandb":
|
||||
if not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
import wandb
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
@@ -552,7 +562,7 @@ def main(args):
|
||||
generator=generator,
|
||||
batch_size=args.eval_batch_size,
|
||||
output_type="numpy",
|
||||
num_inference_steps=args.ddpm_num_steps,
|
||||
num_inference_steps=args.ddpm_num_inference_steps,
|
||||
).images
|
||||
|
||||
# denormalize the images and save to tensorboard
|
||||
@@ -562,6 +572,11 @@ def main(args):
|
||||
accelerator.get_tracker("tensorboard").add_images(
|
||||
"test_samples", images_processed.transpose(0, 3, 1, 2), epoch
|
||||
)
|
||||
elif args.logger == "wandb":
|
||||
accelerator.get_tracker("wandb").log(
|
||||
{"test_samples": [wandb.Image(img) for img in images_processed], "epoch": epoch},
|
||||
step=global_step,
|
||||
)
|
||||
|
||||
if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
|
||||
# save the model
|
||||
|
||||
@@ -22,7 +22,7 @@ import diffusers
|
||||
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import EMAModel
|
||||
from diffusers.utils import check_min_version, is_tensorboard_available
|
||||
from diffusers.utils import check_min_version, is_tensorboard_available, is_wandb_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
@@ -280,6 +280,15 @@ def main(args):
|
||||
logging_dir=logging_dir,
|
||||
)
|
||||
|
||||
if args.logger == "tensorboard":
|
||||
if not is_tensorboard_available():
|
||||
raise ImportError("Make sure to install tensorboard if you want to use it for logging during training.")
|
||||
|
||||
elif args.logger == "wandb":
|
||||
if not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
import wandb
|
||||
|
||||
# `accelerate` 0.16.0 will have better support for customized saving
|
||||
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
@@ -604,10 +613,15 @@ def main(args):
|
||||
# denormalize the images and save to tensorboard
|
||||
images_processed = (images * 255).round().astype("uint8")
|
||||
|
||||
if args.logger == "tensorboard" and is_tensorboard_available():
|
||||
if args.logger == "tensorboard":
|
||||
accelerator.get_tracker("tensorboard").add_images(
|
||||
"test_samples", images_processed.transpose(0, 3, 1, 2), epoch
|
||||
)
|
||||
elif args.logger == "wandb":
|
||||
accelerator.get_tracker("wandb").log(
|
||||
{"test_samples": [wandb.Image(img) for img in images_processed], "epoch": epoch},
|
||||
step=global_step,
|
||||
)
|
||||
|
||||
if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
|
||||
# save the model
|
||||
|
||||
Reference in New Issue
Block a user