From 1dd0ac9401bf89a115ed8adcb515b08f73cc9a46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Radam=C3=A9s=20Ajna?= Date: Thu, 11 Jan 2024 19:45:39 -0800 Subject: [PATCH] [DPO Training] pass tracker name as argument (#6542) pass tracker name as argumentw --- .../diffusion_dpo/train_diffusion_dpo.py | 8 +++++++- .../diffusion_dpo/train_diffusion_dpo_sdxl.py | 8 +++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py b/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py index 06bdd6af60..d1b76323e0 100644 --- a/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py +++ b/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py @@ -414,6 +414,12 @@ def parse_args(input_args=None): default=4, help=("The dimension of the LoRA update matrices."), ) + parser.add_argument( + "--tracker_name", + type=str, + default="diffusion-dpo-lora", + help=("The name of the tracker to report results to."), + ) if input_args is not None: args = parser.parse_args(input_args) @@ -726,7 +732,7 @@ def main(args): # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. if accelerator.is_main_process: - accelerator.init_trackers("diffusion-dpo-lora", config=vars(args)) + accelerator.init_trackers(args.tracker_name, config=vars(args)) # Train! total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps diff --git a/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py b/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py index 0d109f741f..23e94bc679 100644 --- a/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py +++ b/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py @@ -429,6 +429,12 @@ def parse_args(input_args=None): default=4, help=("The dimension of the LoRA update matrices."), ) + parser.add_argument( + "--tracker_name", + type=str, + default="diffusion-dpo-lora-sdxl", + help=("The name of the tracker to report results to."), + ) if input_args is not None: args = parser.parse_args(input_args) @@ -821,7 +827,7 @@ def main(args): # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. if accelerator.is_main_process: - accelerator.init_trackers("diffusion-dpo-lora-sdxl", config=vars(args)) + accelerator.init_trackers(args.tracker_name, config=vars(args)) # Train! total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps