1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

[LoRA] Freezing the model weights (#2245)

* [LoRA] Freezing the model weights

Freeze the model weights since we don't need to calculate grads for them.

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Apply suggestions from code review

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
erkams
2023-02-09 11:45:11 +01:00
committed by GitHub
parent 62a15cec6e
commit 1be7df0205

View File

@@ -415,7 +415,12 @@ def main():
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)
# freeze parameters of models to save more memory
unet.requires_grad_(False)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
# For mixed precision training we cast the text_encoder and vae weights to half-precision
# as these models are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32