diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
index 29fe2744ad..2bf3cc8f7c 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
@@ -300,16 +300,18 @@ def parse_args(input_args=None):
     )
     parser.add_argument(
         "--token_abstraction",
+        type=str,
         default="TOK",
         help="identifier specifying the instance(or instances) as used in instance_prompt, validation prompt, "
-        "captions - e.g. TOK",
+        "captions - e.g. TOK. To use multiple identifiers, please specify them in a comma separated string - e.g. "
+        "'TOK,TOK2,TOK3' etc.",
     )
     parser.add_argument(
         "--num_new_tokens_per_abstraction",
         type=int,
         default=2,
-        help="number of new tokens inserted to the tokenizers per token_abstraction value when "
+        help="number of new tokens inserted to the tokenizers per token_abstraction identifier when "
         "--train_text_encoder_ti = True. By default, each --token_abstraction (e.g. TOK) is mapped to 2 new "
         "tokens - <si><si+1> ",
     )
@@ -660,17 +662,6 @@ def parse_args(input_args=None):
             "inversion training check `--train_text_encoder_ti`"
         )
 
-    if args.train_text_encoder_ti:
-        if isinstance(args.token_abstraction, str):
-            args.token_abstraction = [args.token_abstraction]
-        elif isinstance(args.token_abstraction, List):
-            args.token_abstraction = args.token_abstraction
-        else:
-            raise ValueError(
-                f"Unsupported type for --args.token_abstraction: {type(args.token_abstraction)}. "
-                f"Supported types are: str (for a single instance identifier) or List[str] (for multiple concepts)"
-            )
-
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
     if env_local_rank != -1 and env_local_rank != args.local_rank:
         args.local_rank = env_local_rank
@@ -1155,9 +1146,14 @@ def main(args):
     )
 
     if args.train_text_encoder_ti:
+        # we parse the provided token identifier (or identifiers) into a list. s.t. - "TOK" -> ["TOK"], "TOK,
+        # TOK2" -> ["TOK", "TOK2"] etc.
+        token_abstraction_list = "".join(args.token_abstraction.split()).split(",")
+        logger.info(f"list of token identifiers: {token_abstraction_list}")
+
         token_abstraction_dict = {}
         token_idx = 0
-        for i, token in enumerate(args.token_abstraction):
+        for i, token in enumerate(token_abstraction_list):
             token_abstraction_dict[token] = [
                 f"<s{token_idx + i + j}>" for j in range(args.num_new_tokens_per_abstraction)
             ]
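
Note: the snippet below is a minimal, self-contained sketch (not part of the patch) of the behavior the last hunk introduces. It assumes the placeholder-token format f"<s{...}>" and the token_idx bookkeeping that follows the comprehension in the script, which the hunk truncates; the helper name build_token_abstraction_dict is invented for illustration, and variable names mirror the script only for readability.

# Illustrative sketch: parse a comma separated --token_abstraction string and
# map each identifier to its own consecutive placeholder tokens.
num_new_tokens_per_abstraction = 2  # mirrors the argparse default above

def build_token_abstraction_dict(token_abstraction: str) -> dict:
    # strip all whitespace, then split on commas: "TOK, TOK2" -> ["TOK", "TOK2"]
    token_abstraction_list = "".join(token_abstraction.split()).split(",")

    token_abstraction_dict = {}
    token_idx = 0
    for i, token in enumerate(token_abstraction_list):
        # each identifier gets num_new_tokens_per_abstraction placeholder tokens, e.g. <s0><s1>
        token_abstraction_dict[token] = [
            f"<s{token_idx + i + j}>" for j in range(num_new_tokens_per_abstraction)
        ]
        # advance the offset so the next identifier's tokens do not overlap
        # (assumed from the script's surrounding code, not shown in the hunk)
        token_idx += num_new_tokens_per_abstraction - 1
    return token_abstraction_dict

print(build_token_abstraction_dict("TOK"))        # {'TOK': ['<s0>', '<s1>']}
print(build_token_abstraction_dict("TOK, TOK2"))  # {'TOK': ['<s0>', '<s1>'], 'TOK2': ['<s2>', '<s3>']}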