diff --git a/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py b/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py index 06bdd6af60..d1b76323e0 100644 --- a/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py +++ b/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py @@ -414,6 +414,12 @@ def parse_args(input_args=None): default=4, help=("The dimension of the LoRA update matrices."), ) + parser.add_argument( + "--tracker_name", + type=str, + default="diffusion-dpo-lora", + help=("The name of the tracker to report results to."), + ) if input_args is not None: args = parser.parse_args(input_args) @@ -726,7 +732,7 @@ def main(args): # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. if accelerator.is_main_process: - accelerator.init_trackers("diffusion-dpo-lora", config=vars(args)) + accelerator.init_trackers(args.tracker_name, config=vars(args)) # Train! total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps diff --git a/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py b/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py index 0d109f741f..23e94bc679 100644 --- a/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py +++ b/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py @@ -429,6 +429,12 @@ def parse_args(input_args=None): default=4, help=("The dimension of the LoRA update matrices."), ) + parser.add_argument( + "--tracker_name", + type=str, + default="diffusion-dpo-lora-sdxl", + help=("The name of the tracker to report results to."), + ) if input_args is not None: args = parser.parse_args(input_args) @@ -821,7 +827,7 @@ def main(args): # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. if accelerator.is_main_process: - accelerator.init_trackers("diffusion-dpo-lora-sdxl", config=vars(args)) + accelerator.init_trackers(args.tracker_name, config=vars(args)) # Train! total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps