mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Add 'rank' parameter to Dreambooth LoRA training script (#3945)
This commit is contained in:
@@ -436,6 +436,12 @@ def parse_args(input_args=None):
|
||||
default=None,
|
||||
help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rank",
|
||||
type=int,
|
||||
default=4,
|
||||
help=("The dimension of the LoRA update matrices."),
|
||||
)
|
||||
|
||||
if input_args is not None:
|
||||
args = parser.parse_args(input_args)
|
||||
@@ -845,7 +851,9 @@ def main(args):
|
||||
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
|
||||
)
|
||||
unet_lora_attn_procs[name] = lora_attn_processor_class(
|
||||
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
|
||||
hidden_size=hidden_size,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
rank=args.rank,
|
||||
)
|
||||
|
||||
unet.set_attn_processor(unet_lora_attn_procs)
|
||||
@@ -860,7 +868,9 @@ def main(args):
|
||||
for name, module in text_encoder.named_modules():
|
||||
if name.endswith(TEXT_ENCODER_ATTN_MODULE):
|
||||
text_lora_attn_procs[name] = LoRAAttnProcessor(
|
||||
hidden_size=module.out_proj.out_features, cross_attention_dim=None
|
||||
hidden_size=module.out_proj.out_features,
|
||||
cross_attention_dim=None,
|
||||
rank=args.rank,
|
||||
)
|
||||
text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
|
||||
temp_pipeline = DiffusionPipeline.from_pretrained(
|
||||
|
||||
Reference in New Issue
Block a user