From 4188f3063afae321b45902b0a95ac4dea1004fcb Mon Sep 17 00:00:00 2001
From: Levi McCallum
Date: Thu, 3 Aug 2023 10:57:30 -0700
Subject: [PATCH] Add rank argument to train_dreambooth_lora_sdxl.py (#4343)

* Add rank argument to train_dreambooth_lora_sdxl.py

* Update train_dreambooth_lora_sdxl.py
---
 .../dreambooth/train_dreambooth_lora_sdxl.py | 18 +++++++++++++++---
 1 file changed, 15 insertions(+), 3 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index 0383ab4b99..423dee1568 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -402,6 +402,12 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
+    parser.add_argument(
+        "--rank",
+        type=int,
+        default=4,
+        help=("The dimension of the LoRA update matrices."),
+    )
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -767,7 +773,9 @@ def main(args):
         lora_attn_processor_class = (
             LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
         )
-        module = lora_attn_processor_class(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
+        module = lora_attn_processor_class(
+            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=args.rank
+        )
         unet_lora_attn_procs[name] = module
         unet_lora_parameters.extend(module.parameters())
 
@@ -777,8 +785,12 @@ def main(args):
     # So, instead, we monkey-patch the forward calls of its attention-blocks.
     if args.train_text_encoder:
         # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-        text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(text_encoder_one, dtype=torch.float32)
-        text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(text_encoder_two, dtype=torch.float32)
+        text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(
+            text_encoder_one, dtype=torch.float32, rank=args.rank
+        )
+        text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(
+            text_encoder_two, dtype=torch.float32, rank=args.rank
+        )
 
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
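
For context on what the new --rank flag controls, the sketch below is a minimal, hypothetical LoRA-style linear layer (the class name LoRALinearSketch is invented for illustration; this is not the diffusers LoRAAttnProcessor or _modify_text_encoder implementation). It shows how a rank value sets the inner dimension of the two update matrices, trading adapter capacity against parameter count.

# Illustrative sketch only: a frozen linear layer plus a trainable
# low-rank update y = Wx + B(Ax), where `rank` is the inner dimension
# shared by A and B (analogous to what --rank passes through here).
import torch
import torch.nn as nn


class LoRALinearSketch(nn.Module):
    def __init__(self, base: nn.Linear, rank: int = 4, scale: float = 1.0):
        super().__init__()
        self.base = base
        self.base.requires_grad_(False)  # only the LoRA factors receive gradients
        # A: (rank, in_features), B: (out_features, rank) -> update BA has rank <= `rank`
        self.lora_a = nn.Linear(base.in_features, rank, bias=False)
        self.lora_b = nn.Linear(rank, base.out_features, bias=False)
        nn.init.normal_(self.lora_a.weight, std=1.0 / rank)
        nn.init.zeros_(self.lora_b.weight)  # zero init so training starts as a no-op
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scale * self.lora_b(self.lora_a(x))


if __name__ == "__main__":
    # rank=4 mirrors the patch's default; a larger rank adds capacity and parameters.
    layer = LoRALinearSketch(nn.Linear(320, 320), rank=4)
    print(layer(torch.randn(2, 320)).shape)  # torch.Size([2, 320])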