mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[LoRA] make lora alpha and dropout configurable (#11467)
* add lora_alpha and lora_dropout * Apply style fixes * add lora_alpha and lora_dropout * Apply style fixes * revert lora_alpha until #11324 is merged * Apply style fixes * empty commit --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
@@ -430,6 +430,9 @@ def parse_args(input_args=None):
|
||||
default=4,
|
||||
help=("The dimension of the LoRA update matrices."),
|
||||
)
|
||||
|
||||
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
|
||||
|
||||
parser.add_argument(
|
||||
"--with_prior_preservation",
|
||||
default=False,
|
||||
@@ -1554,6 +1557,7 @@ def main(args):
|
||||
transformer_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
lora_dropout=args.lora_dropout,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=target_modules,
|
||||
)
|
||||
@@ -1562,6 +1566,7 @@ def main(args):
|
||||
text_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
lora_dropout=args.lora_dropout,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
|
||||
)
|
||||
|
||||
@@ -658,6 +658,8 @@ def parse_args(input_args=None):
|
||||
default=4,
|
||||
help=("The dimension of the LoRA update matrices."),
|
||||
)
|
||||
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
|
||||
|
||||
parser.add_argument(
|
||||
"--use_dora",
|
||||
action="store_true",
|
||||
@@ -1248,6 +1250,7 @@ def main(args):
|
||||
unet_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
lora_dropout=args.lora_dropout,
|
||||
use_dora=args.use_dora,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
||||
@@ -1260,6 +1263,7 @@ def main(args):
|
||||
text_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
lora_dropout=args.lora_dropout,
|
||||
use_dora=args.use_dora,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
|
||||
|
||||
@@ -767,6 +767,9 @@ def parse_args(input_args=None):
|
||||
default=4,
|
||||
help=("The dimension of the LoRA update matrices."),
|
||||
)
|
||||
|
||||
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
|
||||
|
||||
parser.add_argument(
|
||||
"--use_dora",
|
||||
action="store_true",
|
||||
@@ -1558,6 +1561,7 @@ def main(args):
|
||||
r=args.rank,
|
||||
use_dora=args.use_dora,
|
||||
lora_alpha=args.rank,
|
||||
lora_dropout=args.lora_dropout,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=target_modules,
|
||||
)
|
||||
@@ -1570,6 +1574,7 @@ def main(args):
|
||||
r=args.rank,
|
||||
use_dora=args.use_dora,
|
||||
lora_alpha=args.rank,
|
||||
lora_dropout=args.lora_dropout,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
|
||||
)
|
||||
|
||||
@@ -524,6 +524,9 @@ def parse_args(input_args=None):
|
||||
default=4,
|
||||
help=("The dimension of the LoRA update matrices."),
|
||||
)
|
||||
|
||||
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
|
||||
|
||||
parser.add_argument(
|
||||
"--image_interpolation_mode",
|
||||
type=str,
|
||||
@@ -932,6 +935,7 @@ def main(args):
|
||||
unet_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
lora_dropout=args.lora_dropout,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
|
||||
)
|
||||
@@ -942,6 +946,7 @@ def main(args):
|
||||
text_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
lora_dropout=args.lora_dropout,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
|
||||
)
|
||||
|
||||
@@ -358,6 +358,9 @@ def parse_args(input_args=None):
|
||||
default=4,
|
||||
help=("The dimension of the LoRA update matrices."),
|
||||
)
|
||||
|
||||
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
|
||||
|
||||
parser.add_argument(
|
||||
"--with_prior_preservation",
|
||||
default=False,
|
||||
@@ -1236,6 +1239,7 @@ def main(args):
|
||||
transformer_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
lora_dropout=args.lora_dropout,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=target_modules,
|
||||
)
|
||||
@@ -1244,6 +1248,7 @@ def main(args):
|
||||
text_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
lora_dropout=args.lora_dropout,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
|
||||
)
|
||||
|
||||
@@ -417,6 +417,9 @@ def parse_args(input_args=None):
|
||||
default=4,
|
||||
help=("The dimension of the LoRA update matrices."),
|
||||
)
|
||||
|
||||
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
|
||||
|
||||
parser.add_argument(
|
||||
"--with_prior_preservation",
|
||||
default=False,
|
||||
@@ -1161,6 +1164,7 @@ def main(args):
|
||||
transformer_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
lora_dropout=args.lora_dropout,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=target_modules,
|
||||
)
|
||||
|
||||
@@ -328,6 +328,9 @@ def parse_args(input_args=None):
|
||||
default=4,
|
||||
help=("The dimension of the LoRA update matrices."),
|
||||
)
|
||||
|
||||
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
|
||||
|
||||
parser.add_argument(
|
||||
"--with_prior_preservation",
|
||||
default=False,
|
||||
@@ -1023,6 +1026,7 @@ def main(args):
|
||||
transformer_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
lora_dropout=args.lora_dropout,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=target_modules,
|
||||
)
|
||||
|
||||
@@ -323,6 +323,9 @@ def parse_args(input_args=None):
|
||||
default=4,
|
||||
help=("The dimension of the LoRA update matrices."),
|
||||
)
|
||||
|
||||
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
|
||||
|
||||
parser.add_argument(
|
||||
"--with_prior_preservation",
|
||||
default=False,
|
||||
@@ -1021,6 +1024,7 @@ def main(args):
|
||||
transformer_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
lora_dropout=args.lora_dropout,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=target_modules,
|
||||
)
|
||||
|
||||
@@ -367,6 +367,9 @@ def parse_args(input_args=None):
|
||||
default=4,
|
||||
help=("The dimension of the LoRA update matrices."),
|
||||
)
|
||||
|
||||
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
|
||||
|
||||
parser.add_argument(
|
||||
"--with_prior_preservation",
|
||||
default=False,
|
||||
@@ -1264,6 +1267,7 @@ def main(args):
|
||||
transformer_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
lora_dropout=args.lora_dropout,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=target_modules,
|
||||
)
|
||||
@@ -1273,6 +1277,7 @@ def main(args):
|
||||
text_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
lora_dropout=args.lora_dropout,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
|
||||
)
|
||||
|
||||
@@ -659,6 +659,9 @@ def parse_args(input_args=None):
|
||||
default=4,
|
||||
help=("The dimension of the LoRA update matrices."),
|
||||
)
|
||||
|
||||
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
|
||||
|
||||
parser.add_argument(
|
||||
"--use_dora",
|
||||
action="store_true",
|
||||
@@ -1199,10 +1202,11 @@ def main(args):
|
||||
text_encoder_one.gradient_checkpointing_enable()
|
||||
text_encoder_two.gradient_checkpointing_enable()
|
||||
|
||||
def get_lora_config(rank, use_dora, target_modules):
|
||||
def get_lora_config(rank, dropout, use_dora, target_modules):
|
||||
base_config = {
|
||||
"r": rank,
|
||||
"lora_alpha": rank,
|
||||
"lora_dropout": dropout,
|
||||
"init_lora_weights": "gaussian",
|
||||
"target_modules": target_modules,
|
||||
}
|
||||
@@ -1218,14 +1222,24 @@ def main(args):
|
||||
|
||||
# now we will add new LoRA weights to the attention layers
|
||||
unet_target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
|
||||
unet_lora_config = get_lora_config(rank=args.rank, use_dora=args.use_dora, target_modules=unet_target_modules)
|
||||
unet_lora_config = get_lora_config(
|
||||
rank=args.rank,
|
||||
dropout=args.lora_dropout,
|
||||
use_dora=args.use_dora,
|
||||
target_modules=unet_target_modules,
|
||||
)
|
||||
unet.add_adapter(unet_lora_config)
|
||||
|
||||
# The text encoder comes from 🤗 transformers, so we cannot directly modify it.
|
||||
# So, instead, we monkey-patch the forward calls of its attention-blocks.
|
||||
if args.train_text_encoder:
|
||||
text_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
|
||||
text_lora_config = get_lora_config(rank=args.rank, use_dora=args.use_dora, target_modules=text_target_modules)
|
||||
text_lora_config = get_lora_config(
|
||||
rank=args.rank,
|
||||
dropout=args.lora_dropout,
|
||||
use_dora=args.use_dora,
|
||||
target_modules=text_target_modules,
|
||||
)
|
||||
text_encoder_one.add_adapter(text_lora_config)
|
||||
text_encoder_two.add_adapter(text_lora_config)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user