1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[DPO Training] pass tracker name as argument (#6542)

pass tracker name as argumentw
This commit is contained in:
Radamés Ajna
2024-01-11 19:45:39 -08:00
committed by GitHub
parent c6b04589b6
commit 1dd0ac9401
2 changed files with 14 additions and 2 deletions

View File

@@ -414,6 +414,12 @@ def parse_args(input_args=None):
default=4,
help=("The dimension of the LoRA update matrices."),
)
parser.add_argument(
"--tracker_name",
type=str,
default="diffusion-dpo-lora",
help=("The name of the tracker to report results to."),
)
if input_args is not None:
args = parser.parse_args(input_args)
@@ -726,7 +732,7 @@ def main(args):
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
accelerator.init_trackers("diffusion-dpo-lora", config=vars(args))
accelerator.init_trackers(args.tracker_name, config=vars(args))
# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

View File

@@ -429,6 +429,12 @@ def parse_args(input_args=None):
default=4,
help=("The dimension of the LoRA update matrices."),
)
parser.add_argument(
"--tracker_name",
type=str,
default="diffusion-dpo-lora-sdxl",
help=("The name of the tracker to report results to."),
)
if input_args is not None:
args = parser.parse_args(input_args)
@@ -821,7 +827,7 @@ def main(args):
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
accelerator.init_trackers("diffusion-dpo-lora-sdxl", config=vars(args))
accelerator.init_trackers(args.tracker_name, config=vars(args))
# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps