
[LoRA] make lora alpha and dropout configurable (#11467)

* add lora_alpha and lora_dropout

* Apply style fixes

* add lora_alpha and lora_dropout

* Apply style fixes

* revert lora_alpha until #11324 is merged

* Apply style fixes

* empty commit

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Author: Linoy Tsaban
Date: 2025-05-08 11:54:50 +03:00
Committed by: GitHub
Parent: c5c34a4591
Commit: 66e50d4e24

10 changed files with 58 additions and 3 deletions
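For context, the new knob feeds straight into `peft`'s `LoraConfig`, alongside the existing `--rank`. Below is a minimal sketch of the configuration the touched training scripts now build, assuming `peft` is installed; the dropout value is a placeholder, and `lora_alpha` still mirrors the rank (per the revert note in the commit message):

```python
# Minimal sketch of the LoraConfig the touched training scripts now build.
# Assumes peft is installed; 0.1 is a placeholder for the new --lora_dropout flag.
from peft import LoraConfig

rank = 4            # --rank (default 4 in these scripts)
lora_dropout = 0.1  # --lora_dropout (new flag, default 0.0)

lora_config = LoraConfig(
    r=rank,
    lora_alpha=rank,            # still tied to the rank until #11324 lands
    lora_dropout=lora_dropout,  # newly configurable
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
```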

View File

@@ -430,6 +430,9 @@ def parse_args(input_args=None):
default=4,
help=("The dimension of the LoRA update matrices."),
)
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
parser.add_argument(
"--with_prior_preservation",
default=False,
@@ -1554,6 +1557,7 @@ def main(args):
transformer_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
+lora_dropout=args.lora_dropout,
init_lora_weights="gaussian",
target_modules=target_modules,
)
@@ -1562,6 +1566,7 @@ def main(args):
text_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
+lora_dropout=args.lora_dropout,
init_lora_weights="gaussian",
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)

View File

@@ -658,6 +658,8 @@ def parse_args(input_args=None):
default=4,
help=("The dimension of the LoRA update matrices."),
)
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
parser.add_argument(
"--use_dora",
action="store_true",
@@ -1248,6 +1250,7 @@ def main(args):
unet_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
+lora_dropout=args.lora_dropout,
use_dora=args.use_dora,
init_lora_weights="gaussian",
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
@@ -1260,6 +1263,7 @@ def main(args):
text_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
+lora_dropout=args.lora_dropout,
use_dora=args.use_dora,
init_lora_weights="gaussian",
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],

View File

@@ -767,6 +767,9 @@ def parse_args(input_args=None):
default=4,
help=("The dimension of the LoRA update matrices."),
)
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
parser.add_argument(
"--use_dora",
action="store_true",
@@ -1558,6 +1561,7 @@ def main(args):
r=args.rank,
use_dora=args.use_dora,
lora_alpha=args.rank,
+lora_dropout=args.lora_dropout,
init_lora_weights="gaussian",
target_modules=target_modules,
)
@@ -1570,6 +1574,7 @@ def main(args):
r=args.rank,
use_dora=args.use_dora,
lora_alpha=args.rank,
+lora_dropout=args.lora_dropout,
init_lora_weights="gaussian",
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)

View File

@@ -524,6 +524,9 @@ def parse_args(input_args=None):
default=4,
help=("The dimension of the LoRA update matrices."),
)
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
parser.add_argument(
"--image_interpolation_mode",
type=str,
@@ -932,6 +935,7 @@ def main(args):
unet_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
+lora_dropout=args.lora_dropout,
init_lora_weights="gaussian",
target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
)
@@ -942,6 +946,7 @@ def main(args):
text_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
+lora_dropout=args.lora_dropout,
init_lora_weights="gaussian",
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)

View File

@@ -358,6 +358,9 @@ def parse_args(input_args=None):
default=4,
help=("The dimension of the LoRA update matrices."),
)
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
parser.add_argument(
"--with_prior_preservation",
default=False,
@@ -1236,6 +1239,7 @@ def main(args):
transformer_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
+lora_dropout=args.lora_dropout,
init_lora_weights="gaussian",
target_modules=target_modules,
)
@@ -1244,6 +1248,7 @@ def main(args):
text_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
+lora_dropout=args.lora_dropout,
init_lora_weights="gaussian",
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)

View File

@@ -417,6 +417,9 @@ def parse_args(input_args=None):
default=4,
help=("The dimension of the LoRA update matrices."),
)
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
parser.add_argument(
"--with_prior_preservation",
default=False,
@@ -1161,6 +1164,7 @@ def main(args):
transformer_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
+lora_dropout=args.lora_dropout,
init_lora_weights="gaussian",
target_modules=target_modules,
)

View File

@@ -328,6 +328,9 @@ def parse_args(input_args=None):
default=4,
help=("The dimension of the LoRA update matrices."),
)
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
parser.add_argument(
"--with_prior_preservation",
default=False,
@@ -1023,6 +1026,7 @@ def main(args):
transformer_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
+lora_dropout=args.lora_dropout,
init_lora_weights="gaussian",
target_modules=target_modules,
)

View File

@@ -323,6 +323,9 @@ def parse_args(input_args=None):
default=4,
help=("The dimension of the LoRA update matrices."),
)
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
parser.add_argument(
"--with_prior_preservation",
default=False,
@@ -1021,6 +1024,7 @@ def main(args):
transformer_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
+lora_dropout=args.lora_dropout,
init_lora_weights="gaussian",
target_modules=target_modules,
)

View File

@@ -367,6 +367,9 @@ def parse_args(input_args=None):
default=4,
help=("The dimension of the LoRA update matrices."),
)
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
parser.add_argument(
"--with_prior_preservation",
default=False,
@@ -1264,6 +1267,7 @@ def main(args):
transformer_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
+lora_dropout=args.lora_dropout,
init_lora_weights="gaussian",
target_modules=target_modules,
)
@@ -1273,6 +1277,7 @@ def main(args):
text_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
+lora_dropout=args.lora_dropout,
init_lora_weights="gaussian",
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)

View File

@@ -659,6 +659,9 @@ def parse_args(input_args=None):
default=4,
help=("The dimension of the LoRA update matrices."),
)
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
parser.add_argument(
"--use_dora",
action="store_true",
@@ -1199,10 +1202,11 @@ def main(args):
text_encoder_one.gradient_checkpointing_enable()
text_encoder_two.gradient_checkpointing_enable()
-def get_lora_config(rank, use_dora, target_modules):
+def get_lora_config(rank, dropout, use_dora, target_modules):
base_config = {
"r": rank,
"lora_alpha": rank,
"lora_dropout": dropout,
"init_lora_weights": "gaussian",
"target_modules": target_modules,
}
@@ -1218,14 +1222,24 @@ def main(args):
# now we will add new LoRA weights to the attention layers
unet_target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
-unet_lora_config = get_lora_config(rank=args.rank, use_dora=args.use_dora, target_modules=unet_target_modules)
+unet_lora_config = get_lora_config(
+    rank=args.rank,
+    dropout=args.lora_dropout,
+    use_dora=args.use_dora,
+    target_modules=unet_target_modules,
+)
unet.add_adapter(unet_lora_config)
# The text encoder comes from 🤗 transformers, so we cannot directly modify it.
# So, instead, we monkey-patch the forward calls of its attention-blocks.
if args.train_text_encoder:
text_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
-text_lora_config = get_lora_config(rank=args.rank, use_dora=args.use_dora, target_modules=text_target_modules)
+text_lora_config = get_lora_config(
+    rank=args.rank,
+    dropout=args.lora_dropout,
+    use_dora=args.use_dora,
+    target_modules=text_target_modules,
+)
text_encoder_one.add_adapter(text_lora_config)
text_encoder_two.add_adapter(text_lora_config)
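
Usage note: after this change, passing e.g. `--lora_dropout 0.1` to any of the touched training scripts enables dropout inside the LoRA layers; the default of 0.0 keeps the previous behavior. As a rough illustration of what the knob controls (a conventional LoRA formulation, not the peft source), dropout is applied to the layer input before the low-rank update, which is scaled by `lora_alpha / r`:

```python
# Minimal sketch of how lora_dropout is conventionally applied in a LoRA
# linear layer; this is an illustration, not the peft implementation.
import torch
import torch.nn as nn

class LoRALinearSketch(nn.Module):
    def __init__(self, base: nn.Linear, r: int = 4, lora_alpha: int = 4, lora_dropout: float = 0.0):
        super().__init__()
        self.base = base                                # frozen pretrained projection
        self.lora_A = nn.Linear(base.in_features, r, bias=False)
        self.lora_B = nn.Linear(r, base.out_features, bias=False)
        nn.init.zeros_(self.lora_B.weight)              # update starts as a no-op
        self.dropout = nn.Dropout(p=lora_dropout)       # the knob this commit exposes
        self.scaling = lora_alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # base output + scaled low-rank update computed on the dropped-out input
        return self.base(x) + self.lora_B(self.lora_A(self.dropout(x))) * self.scaling
```

Since `nn.Dropout` is a no-op in eval mode, a nonzero `lora_dropout` only regularizes training and does not change inference.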