mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Small modification to enable usage by external scripts (#956)
* Make training code usable by external scripts Add parameter inputs to training and argument parsing function to allow this script to be used by an external call. * Apply suggestions from code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -26,7 +26,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def parse_args():
|
||||
def parse_args(input_args):
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument(
|
||||
"--pretrained_model_name_or_path",
|
||||
@@ -196,7 +196,11 @@ def parse_args():
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
|
||||
args = parser.parse_args()
|
||||
if input_args is not None:
|
||||
args = parser.parse_args(input_args)
|
||||
else:
|
||||
args = parser.parse_args()
|
||||
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
||||
args.local_rank = env_local_rank
|
||||
@@ -319,8 +323,7 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
|
||||
return f"{organization}/{model_id}"
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
def main(args):
|
||||
logging_dir = Path(args.output_dir, args.logging_dir)
|
||||
|
||||
accelerator = Accelerator(
|
||||
@@ -653,4 +656,5 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
args = parse_args()
|
||||
main(args)
|
||||
|
||||
Reference in New Issue
Block a user