Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
Add rank argument to train_dreambooth_lora_sdxl.py (#4343)
* Add rank argument to train_dreambooth_lora_sdxl.py
* Update train_dreambooth_lora_sdxl.py
@@ -402,6 +402,12 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
+    parser.add_argument(
+        "--rank",
+        type=int,
+        default=4,
+        help=("The dimension of the LoRA update matrices."),
+    )
 
     if input_args is not None:
         args = parser.parse_args(input_args)
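For reference, a minimal runnable sketch (not part of the patch) of how the new flag behaves once parsed; the parser below is a stand-alone illustration, not the script's full parse_args, and the value 8 is just an example:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--rank",
    type=int,
    default=4,
    help="The dimension of the LoRA update matrices.",
)

# Passing the flag overrides the default; omitting it keeps rank=4.
args = parser.parse_args(["--rank", "8"])
print(args.rank)  # -> 8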
@@ -767,7 +773,9 @@ def main(args):
         lora_attn_processor_class = (
             LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
         )
-        module = lora_attn_processor_class(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
+        module = lora_attn_processor_class(
+            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=args.rank
+        )
         unet_lora_attn_procs[name] = module
         unet_lora_parameters.extend(module.parameters())
 
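A minimal sketch of what the patched loop now does for a single attention block, assuming the LoRAAttnProcessor classes accept the hidden_size/cross_attention_dim/rank keywords used above and the import path this diffusers version exposes; the concrete sizes are illustrative, not taken from the script:

import torch.nn.functional as F
from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0

# Prefer the PyTorch 2.0 processor when scaled_dot_product_attention is available.
lora_attn_processor_class = (
    LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
)

# rank now comes from --rank instead of the processor's default;
# hidden_size/cross_attention_dim below are example values only.
module = lora_attn_processor_class(hidden_size=1280, cross_attention_dim=2048, rank=8)
lora_parameters = list(module.parameters())  # trainable LoRA weights for this block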
@@ -777,8 +785,12 @@ def main(args):
     # So, instead, we monkey-patch the forward calls of its attention-blocks.
     if args.train_text_encoder:
         # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-        text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(text_encoder_one, dtype=torch.float32)
-        text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(text_encoder_two, dtype=torch.float32)
+        text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(
+            text_encoder_one, dtype=torch.float32, rank=args.rank
+        )
+        text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(
+            text_encoder_two, dtype=torch.float32, rank=args.rank
+        )
 
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
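To make the effect of rank concrete, here is a toy sketch of the low-rank update that both the UNet and text-encoder LoRA layers learn; the names and sizes are hypothetical and not from diffusers, but they show that rank sets the inner dimension of the update and therefore the trainable-parameter count:

import torch

# Hypothetical layer sizes; only `rank` corresponds to the new --rank flag.
in_features, out_features, rank = 768, 768, 4
frozen_weight = torch.randn(out_features, in_features)  # base weight, stays frozen
lora_down = torch.randn(rank, in_features) * 0.01       # "A": projects input down to rank
lora_up = torch.zeros(out_features, rank)               # "B": projects back up, starts at zero

x = torch.randn(1, in_features)
y = x @ (frozen_weight + lora_up @ lora_down).T         # base output + low-rank correction
print(lora_up.numel() + lora_down.numel())              # trainable params grow linearly with rank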