From cdf2ae8a8426d198a108242dc933c39763c8ccc3 Mon Sep 17 00:00:00 2001
From: takuoko
Date: Thu, 29 Jun 2023 20:39:59 +0900
Subject: [PATCH] [Enhance] Add LoRA rank args in train_text_to_image_lora
 (#3866)

* add rank args in lora finetune

* del network_alpha
---
 examples/text_to_image/train_text_to_image_lora.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py
index 7c2601d8e9..29259e408e 100644
--- a/examples/text_to_image/train_text_to_image_lora.py
+++ b/examples/text_to_image/train_text_to_image_lora.py
@@ -343,6 +343,12 @@ def parse_args():
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
     parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+    parser.add_argument(
+        "--rank",
+        type=int,
+        default=4,
+        help=("The dimension of the LoRA update matrices."),
+    )
 
     args = parser.parse_args()
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -464,7 +470,11 @@ def main():
             block_id = int(name[len("down_blocks.")])
             hidden_size = unet.config.block_out_channels[block_id]
 
-        lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
+        lora_attn_procs[name] = LoRAAttnProcessor(
+            hidden_size=hidden_size,
+            cross_attention_dim=cross_attention_dim,
+            rank=args.rank,
+        )
 
     unet.set_attn_processor(lora_attn_procs)
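
For reference, below is a minimal sketch of how the new flag flows into the attention processors. It is not part of the patch: the import path is assumed to match the diffusers version this patch targets, and the hidden_size/cross_attention_dim values are illustrative placeholders, not values taken from the script.

    # Sketch only: parse the new --rank flag and build one LoRA attention processor with it.
    # Assumes LoRAAttnProcessor from the diffusers version this patch targets.
    import argparse

    from diffusers.models.attention_processor import LoRAAttnProcessor

    parser = argparse.ArgumentParser()
    parser.add_argument("--rank", type=int, default=4, help="The dimension of the LoRA update matrices.")
    args = parser.parse_args(["--rank", "8"])

    # Placeholder dimensions for illustration; a rank-8 processor allocates larger
    # low-rank update matrices than the previously hard-coded default rank of 4.
    proc = LoRAAttnProcessor(hidden_size=320, cross_attention_dim=768, rank=args.rank)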