From 1be7df0205d761c50213aa70967a74141dd64849 Mon Sep 17 00:00:00 2001
From: erkams
Date: Thu, 9 Feb 2023 11:45:11 +0100
Subject: [PATCH] [LoRA] Freezing the model weights (#2245)

* [LoRA] Freezing the model weights

Freeze the model weights since we don't need to calculate grads for them.

* Apply suggestions from code review

Co-authored-by: Patrick von Platen

* Apply suggestions from code review

---------

Co-authored-by: Patrick von Platen
Co-authored-by: Suraj Patil
---
 examples/text_to_image/train_text_to_image_lora.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py
index b56e0dca53..a3c5bef73a 100644
--- a/examples/text_to_image/train_text_to_image_lora.py
+++ b/examples/text_to_image/train_text_to_image_lora.py
@@ -415,7 +415,12 @@ def main():
     unet = UNet2DConditionModel.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
     )
-
+    # freeze parameters of models to save more memory
+    unet.requires_grad_(False)
+    vae.requires_grad_(False)
+
+    text_encoder.requires_grad_(False)
+
     # For mixed precision training we cast the text_encoder and vae weights to half-precision
     # as these models are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32
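
A minimal, self-contained sketch of the freezing pattern the patch applies, for readers who want to see its effect outside the training script. The toy nn.Linear modules stand in for the real UNet, VAE, and text encoder, and the hand-rolled low-rank adapter is an illustrative assumption, not diffusers' actual LoRA layers.

import torch
import torch.nn as nn

# Toy stand-in for the frozen base models; the patch freezes unet, vae and
# text_encoder the same way, via nn.Module.requires_grad_(False).
base = nn.Linear(512, 512)
base.requires_grad_(False)  # no gradients are computed or stored for these weights

# Hand-rolled low-rank adapter (illustrative only, not diffusers' LoRA implementation).
lora_down = nn.Linear(512, 4, bias=False)
lora_up = nn.Linear(4, 512, bias=False)

# Only trainable parameters reach the optimizer; the memory saving comes from
# skipping gradient buffers and optimizer state for the frozen base weights.
trainable = [p for p in [*lora_down.parameters(), *lora_up.parameters()] if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-4)

x = torch.randn(2, 512)
loss = (base(x) + lora_up(lora_down(x))).sum()  # base output plus low-rank update
loss.backward()
assert base.weight.grad is None  # the frozen weights accumulated no gradient
optimizer.step()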