1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Enhance] Add LoRA rank args in train_text_to_image_lora (#3866)

* add rank args in lora finetune

* del network_alpha
This commit is contained in:
takuoko
2023-06-29 20:39:59 +09:00
committed by GitHub
parent 49949f321d
commit cdf2ae8a84

View File

@@ -343,6 +343,12 @@ def parse_args():
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
parser.add_argument(
"--rank",
type=int,
default=4,
help=("The dimension of the LoRA update matrices."),
)
args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -464,7 +470,11 @@ def main():
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
lora_attn_procs[name] = LoRAAttnProcessor(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
rank=args.rank,
)
unet.set_attn_processor(lora_attn_procs)