From 880c0fdd365f3c91c2d65cbbf97df7d2ab98bd92 Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Mon, 4 Dec 2023 19:38:44 +0200 Subject: [PATCH] [advanced dreambooth lora training script][bug_fix] change token_abstraction type to str (#6040) * improve help tags * style fix * changes token_abstraction type to string. support multiple concepts for pivotal using a comma separated string. * style fixup * changed logger to warning (not yet available) * moved the token_abstraction parsing to be in the same block as where we create the mapping of identifier to token --------- Co-authored-by: Linoy --- .../train_dreambooth_lora_sdxl_advanced.py | 24 ++++++++----------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index 29fe2744ad..2bf3cc8f7c 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -300,16 +300,18 @@ def parse_args(input_args=None): ) parser.add_argument( "--token_abstraction", + type=str, default="TOK", help="identifier specifying the instance(or instances) as used in instance_prompt, validation prompt, " - "captions - e.g. TOK", + "captions - e.g. TOK. To use multiple identifiers, please specify them in a comma seperated string - e.g. " + "'TOK,TOK2,TOK3' etc.", ) parser.add_argument( "--num_new_tokens_per_abstraction", type=int, default=2, - help="number of new tokens inserted to the tokenizers per token_abstraction value when " + help="number of new tokens inserted to the tokenizers per token_abstraction identifier when " "--train_text_encoder_ti = True. By default, each --token_abstraction (e.g. TOK) is mapped to 2 new " "tokens - ", ) @@ -660,17 +662,6 @@ def parse_args(input_args=None): "inversion training check `--train_text_encoder_ti`" ) - if args.train_text_encoder_ti: - if isinstance(args.token_abstraction, str): - args.token_abstraction = [args.token_abstraction] - elif isinstance(args.token_abstraction, List): - args.token_abstraction = args.token_abstraction - else: - raise ValueError( - f"Unsupported type for --args.token_abstraction: {type(args.token_abstraction)}. " - f"Supported types are: str (for a single instance identifier) or List[str] (for multiple concepts)" - ) - env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: args.local_rank = env_local_rank @@ -1155,9 +1146,14 @@ def main(args): ) if args.train_text_encoder_ti: + # we parse the provided token identifier (or identifiers) into a list. s.t. - "TOK" -> ["TOK"], "TOK, + # TOK2" -> ["TOK", "TOK2"] etc. + token_abstraction_list = "".join(args.token_abstraction.split()).split(",") + logger.info(f"list of token identifiers: {token_abstraction_list}") + token_abstraction_dict = {} token_idx = 0 - for i, token in enumerate(args.token_abstraction): + for i, token in enumerate(token_abstraction_list): token_abstraction_dict[token] = [ f"" for j in range(args.num_new_tokens_per_abstraction) ]