
Add 'rank' parameter to Dreambooth LoRA training script (#3945)

Author: Batuhan Taskaya
Date: 2023-07-07 07:56:10 -04:00
Committed by: GitHub
Parent: 03d829d59e
Commit: 04ddad484e


@@ -436,6 +436,12 @@ def parse_args(input_args=None):
         default=None,
         help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.",
     )
+    parser.add_argument(
+        "--rank",
+        type=int,
+        default=4,
+        help=("The dimension of the LoRA update matrices."),
+    )
 
     if input_args is not None:
         args = parser.parse_args(input_args)
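
The hunk above only registers the new flag. As a minimal, self-contained sketch (mirroring just the added argument, not the full parse_args of train_dreambooth_lora.py), the flag behaves like this under argparse:

    # Standalone sketch of the new flag; only `--rank` is mirrored here.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--rank",
        type=int,
        default=4,
        help="The dimension of the LoRA update matrices.",
    )

    args = parser.parse_args(["--rank", "8"])
    print(args.rank)  # 8; without --rank on the command line this stays at the default of 4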
@@ -845,7 +851,9 @@ def main(args):
             LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
         )
         unet_lora_attn_procs[name] = lora_attn_processor_class(
-            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
+            hidden_size=hidden_size,
+            cross_attention_dim=cross_attention_dim,
+            rank=args.rank,
         )
 
     unet.set_attn_processor(unet_lora_attn_procs)
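
For context, `rank` sets the inner dimension of the two low-rank matrices LoRA trains for each attention projection. The sketch below is a generic PyTorch illustration of that idea, not the diffusers `LoRAAttnProcessor` implementation; the class and variable names are made up for the example.

    # Generic LoRA update: y = up(down(x)), where down maps hidden -> rank and
    # up maps rank -> hidden, so trainable parameters grow linearly with `rank`.
    import torch
    import torch.nn as nn

    class LoRAUpdate(nn.Module):
        def __init__(self, hidden_size: int, rank: int = 4):
            super().__init__()
            self.down = nn.Linear(hidden_size, rank, bias=False)  # hidden_size -> rank
            self.up = nn.Linear(rank, hidden_size, bias=False)    # rank -> hidden_size
            nn.init.normal_(self.down.weight, std=1.0 / rank)
            nn.init.zeros_(self.up.weight)  # update starts as a no-op

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.up(self.down(x))

    update = LoRAUpdate(hidden_size=320, rank=8)
    x = torch.randn(2, 77, 320)
    print(update(x).shape)                              # torch.Size([2, 77, 320])
    print(sum(p.numel() for p in update.parameters()))  # 2 * 320 * 8 = 5120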
@@ -860,7 +868,9 @@ def main(args):
         for name, module in text_encoder.named_modules():
             if name.endswith(TEXT_ENCODER_ATTN_MODULE):
                 text_lora_attn_procs[name] = LoRAAttnProcessor(
-                    hidden_size=module.out_proj.out_features, cross_attention_dim=None
+                    hidden_size=module.out_proj.out_features,
+                    cross_attention_dim=None,
+                    rank=args.rank,
                 )
         text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
         temp_pipeline = DiffusionPipeline.from_pretrained(
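
Assuming a diffusers version around this commit (~0.18), where `LoRAAttnProcessor` accepts the `rank` keyword as shown in the hunks, one way to see what the flag changes for the text-encoder processors might be the following; the hidden size of 768 is an assumption matching the SD 1.x CLIP text encoder:

    # Hedged sketch: build one text-encoder LoRA processor at two ranks and compare sizes.
    from diffusers.models.attention_processor import LoRAAttnProcessor

    for rank in (4, 8):
        proc = LoRAAttnProcessor(
            hidden_size=768,           # assumed CLIP text-encoder width for SD 1.x
            cross_attention_dim=None,  # text-encoder attention is self-attention only
            rank=rank,                 # what --rank feeds in via args.rank
        )
        print(rank, sum(p.numel() for p in proc.parameters()))  # parameter count scales with rank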