
[advanced dreambooth lora training script][bug_fix] change token_abstraction type to str (#6040)

* improve help tags

* style fix

* changed token_abstraction type to string;
support multiple concepts for pivotal tuning via a comma-separated string

* style fixup

* changed logger call to a warning (logger not yet available at that point)

* moved the token_abstraction parsing into the same block where we create the identifier-to-token mapping

---------

Co-authored-by: Linoy <linoy@huggingface.co>
commit 880c0fdd36 (parent c36f1c3160)
Author: Linoy Tsaban
Date: 2023-12-04 19:38:44 +02:00
Committed by: GitHub


@@ -300,16 +300,18 @@ def parse_args(input_args=None):
     )
     parser.add_argument(
         "--token_abstraction",
+        type=str,
         default="TOK",
         help="identifier specifying the instance(or instances) as used in instance_prompt, validation prompt, "
-        "captions - e.g. TOK",
+        "captions - e.g. TOK. To use multiple identifiers, please specify them in a comma seperated string - e.g. "
+        "'TOK,TOK2,TOK3' etc.",
     )
     parser.add_argument(
         "--num_new_tokens_per_abstraction",
         type=int,
         default=2,
-        help="number of new tokens inserted to the tokenizers per token_abstraction value when "
+        help="number of new tokens inserted to the tokenizers per token_abstraction identifier when "
         "--train_text_encoder_ti = True. By default, each --token_abstraction (e.g. TOK) is mapped to 2 new "
         "tokens - <si><si+1> ",
     )
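With `type=str`, multiple concepts are now passed as one comma-separated value instead of a list. A minimal sketch of how the changed flag behaves in isolation (sample values are hypothetical; only --token_abstraction is modeled here):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--token_abstraction", type=str, default="TOK")

# Single concept: the default, or any bare identifier.
args = parser.parse_args([])
print(args.token_abstraction)  # "TOK"

# Multiple concepts: one comma-separated string.
args = parser.parse_args(["--token_abstraction", "TOK,TOK2,TOK3"])
print(args.token_abstraction)  # "TOK,TOK2,TOK3" (split later, in main())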
@@ -660,17 +662,6 @@
             "inversion training check `--train_text_encoder_ti`"
         )
-    if args.train_text_encoder_ti:
-        if isinstance(args.token_abstraction, str):
-            args.token_abstraction = [args.token_abstraction]
-        elif isinstance(args.token_abstraction, List):
-            args.token_abstraction = args.token_abstraction
-        else:
-            raise ValueError(
-                f"Unsupported type for --args.token_abstraction: {type(args.token_abstraction)}. "
-                f"Supported types are: str (for a single instance identifier) or List[str] (for multiple concepts)"
-            )
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
     if env_local_rank != -1 and env_local_rank != args.local_rank:
         args.local_rank = env_local_rank
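The deleted normalization was effectively dead code: with no nargs, argparse stores --token_abstraction as exactly one string (and the default is a string too), so the str branch always fired and the List branch was unreachable unless callers bypassed argparse entirely. A quick illustration of that invariant (sample value is hypothetical):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--token_abstraction", type=str, default="TOK")

# Without nargs, the parsed value is always a single string, never a list,
# so an isinstance() dispatch on it has only one live branch.
args = parser.parse_args(["--token_abstraction", "TOK,TOK2"])
assert isinstance(args.token_abstraction, str)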
@@ -1155,9 +1146,14 @@ def main(args):
     )
     if args.train_text_encoder_ti:
+        # we parse the provided token identifier (or identifiers) into a list. s.t. - "TOK" -> ["TOK"], "TOK,
+        # TOK2" -> ["TOK", "TOK2"] etc.
+        token_abstraction_list = "".join(args.token_abstraction.split()).split(",")
+        logger.info(f"list of token identifiers: {token_abstraction_list}")
         token_abstraction_dict = {}
         token_idx = 0
-        for i, token in enumerate(args.token_abstraction):
+        for i, token in enumerate(token_abstraction_list):
             token_abstraction_dict[token] = [
                 f"<s{token_idx + i + j}>" for j in range(args.num_new_tokens_per_abstraction)
             ]
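The `"".join(...).split(",")` idiom removes all whitespace before splitting, so identifiers survive stray spaces around the commas. A standalone sketch of just that parsing step (the helper name and sample inputs are hypothetical; the script does this inline):

from typing import List

def parse_token_abstraction(value: str) -> List[str]:
    # Strip all whitespace, then split on commas:
    # "TOK" -> ["TOK"], "TOK, TOK2" -> ["TOK", "TOK2"]
    return "".join(value.split()).split(",")

assert parse_token_abstraction("TOK") == ["TOK"]
assert parse_token_abstraction("TOK, TOK2,  TOK3") == ["TOK", "TOK2", "TOK3"]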