Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
[advanced dreambooth lora training script][bug_fix] change token_abstraction type to str (#6040)
* improve help tags
* style fix
* change token_abstraction type to string; support multiple concepts for pivotal tuning via a comma-separated string
* style fixup
* changed logger to warning (logger not yet available at that point)
* moved the token_abstraction parsing into the same block where we create the mapping of identifier to token

Co-authored-by: Linoy <linoy@huggingface.co>
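For reference, a minimal sketch of the identifier parsing this commit introduces, assuming nothing beyond what the diff below shows; the helper name parse_token_abstraction is illustrative only, since the script performs the same split inline on args.token_abstraction:

# Minimal sketch of the comma-separated identifier parsing introduced here.
# `parse_token_abstraction` is an illustrative name, not a function in the script.
def parse_token_abstraction(token_abstraction: str) -> list:
    # Strip all whitespace, then split on commas: "TOK, TOK2" -> ["TOK", "TOK2"]
    return "".join(token_abstraction.split()).split(",")

assert parse_token_abstraction("TOK") == ["TOK"]
assert parse_token_abstraction("TOK, TOK2, TOK3") == ["TOK", "TOK2", "TOK3"]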
@@ -300,16 +300,18 @@ def parse_args(input_args=None):
     )
     parser.add_argument(
         "--token_abstraction",
+        type=str,
         default="TOK",
         help="identifier specifying the instance(or instances) as used in instance_prompt, validation prompt, "
-        "captions - e.g. TOK",
+        "captions - e.g. TOK. To use multiple identifiers, please specify them in a comma seperated string - e.g. "
+        "'TOK,TOK2,TOK3' etc.",
     )
 
     parser.add_argument(
         "--num_new_tokens_per_abstraction",
         type=int,
         default=2,
-        help="number of new tokens inserted to the tokenizers per token_abstraction value when "
+        help="number of new tokens inserted to the tokenizers per token_abstraction identifier when "
         "--train_text_encoder_ti = True. By default, each --token_abstraction (e.g. TOK) is mapped to 2 new "
         "tokens - <si><si+1> ",
     )
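Because --token_abstraction is now a plain string, multiple identifiers travel through argparse as a single comma-separated value rather than repeated flags. A stripped-down stand-in for the script's parser, covering only the two options above with the defaults from the diff (not the full parse_args), illustrates this:

import argparse

# Stand-in parser with only the two options touched by this change;
# defaults mirror the diff above.
parser = argparse.ArgumentParser()
parser.add_argument("--token_abstraction", type=str, default="TOK")
parser.add_argument("--num_new_tokens_per_abstraction", type=int, default=2)

# Multiple concepts are passed as one comma-separated string, not repeated flags.
args = parser.parse_args(["--token_abstraction", "TOK,TOK2,TOK3"])
print(args.token_abstraction)               # "TOK,TOK2,TOK3" (a plain str)
print(args.num_new_tokens_per_abstraction)  # 2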
@@ -660,17 +662,6 @@ def parse_args(input_args=None):
             "inversion training check `--train_text_encoder_ti`"
         )
 
-    if args.train_text_encoder_ti:
-        if isinstance(args.token_abstraction, str):
-            args.token_abstraction = [args.token_abstraction]
-        elif isinstance(args.token_abstraction, List):
-            args.token_abstraction = args.token_abstraction
-        else:
-            raise ValueError(
-                f"Unsupported type for --args.token_abstraction: {type(args.token_abstraction)}. "
-                f"Supported types are: str (for a single instance identifier) or List[str] (for multiple concepts)"
-            )
-
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
     if env_local_rank != -1 and env_local_rank != args.local_rank:
         args.local_rank = env_local_rank
@@ -1155,9 +1146,14 @@ def main(args):
     )
 
     if args.train_text_encoder_ti:
+        # we parse the provided token identifier (or identifiers) into a list. s.t. - "TOK" -> ["TOK"], "TOK,
+        # TOK2" -> ["TOK", "TOK2"] etc.
+        token_abstraction_list = "".join(args.token_abstraction.split()).split(",")
+        logger.info(f"list of token identifiers: {token_abstraction_list}")
+
         token_abstraction_dict = {}
         token_idx = 0
-        for i, token in enumerate(args.token_abstraction):
+        for i, token in enumerate(token_abstraction_list):
             token_abstraction_dict[token] = [
                 f"<s{token_idx + i + j}>" for j in range(args.num_new_tokens_per_abstraction)
             ]
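To make the mapping concrete, here is a worked sketch for the defaults (--token_abstraction="TOK", --num_new_tokens_per_abstraction=2). It reproduces only the lines visible in this hunk; how token_idx is advanced between multiple identifiers is outside the shown context, so a single identifier is used:

# Worked example of the identifier -> new-token mapping for the defaults.
# Only the logic visible in the hunk above is reproduced here.
token_abstraction_list = ["TOK"]    # result of splitting the default "TOK"
num_new_tokens_per_abstraction = 2  # default of --num_new_tokens_per_abstraction

token_abstraction_dict = {}
token_idx = 0
for i, token in enumerate(token_abstraction_list):
    token_abstraction_dict[token] = [
        f"<s{token_idx + i + j}>" for j in range(num_new_tokens_per_abstraction)
    ]

print(token_abstraction_dict)  # {'TOK': ['<s0>', '<s1>']}, i.e. TOK maps to <s0><s1>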