diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py
index 7c2601d8e9..29259e408e 100644
--- a/examples/text_to_image/train_text_to_image_lora.py
+++ b/examples/text_to_image/train_text_to_image_lora.py
@@ -343,6 +343,12 @@ def parse_args():
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
     parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+    parser.add_argument(
+        "--rank",
+        type=int,
+        default=4,
+        help=("The dimension of the LoRA update matrices."),
+    )
 
     args = parser.parse_args()
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -464,7 +470,11 @@ def main():
             block_id = int(name[len("down_blocks.")])
             hidden_size = unet.config.block_out_channels[block_id]
 
-        lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
+        lora_attn_procs[name] = LoRAAttnProcessor(
+            hidden_size=hidden_size,
+            cross_attention_dim=cross_attention_dim,
+            rank=args.rank,
+        )
 
     unet.set_attn_processor(lora_attn_procs)
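
For context, this is what the `rank` hyperparameter controls. The sketch below is a minimal, self-contained illustration of a rank-`r` LoRA update around a frozen linear layer, not the diffusers `LoRAAttnProcessor` implementation itself (it omits the alpha/rank scaling factor, and `LoRALinearSketch` is a hypothetical name used only here):

```python
import torch
import torch.nn as nn


class LoRALinearSketch(nn.Module):
    """Minimal sketch of a rank-r LoRA update around a frozen linear layer.

    The forward pass computes base(x) + B(A(x)), where A has shape
    (rank, in_features) and B has shape (out_features, rank). The `rank`
    value therefore sets the number of trainable parameters:
    rank * (in_features + out_features) per adapted layer.
    """

    def __init__(self, base: nn.Linear, rank: int = 4):
        super().__init__()
        self.base = base
        self.base.requires_grad_(False)  # pretrained weight stays frozen
        self.lora_a = nn.Linear(base.in_features, rank, bias=False)
        self.lora_b = nn.Linear(rank, base.out_features, bias=False)
        nn.init.zeros_(self.lora_b.weight)  # update starts as a no-op

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.lora_b(self.lora_a(x))
```

A larger rank makes the low-rank update more expressive at the cost of more trainable parameters. The default of 4 in the new `--rank` argument matches the default `rank` of `LoRAAttnProcessor`, so existing training commands that do not pass the flag keep their previous behavior.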