1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Training Utils] create a utility for casting the lora params during training. (#6553)

create a utility for casting the lora params during training.
This commit is contained in:
Sayak Paul
2024-01-15 13:51:13 +05:30
committed by GitHub
parent 79df50388d
commit a080f0d3a2
5 changed files with 26 additions and 28 deletions

View File

@@ -1,7 +1,7 @@
import contextlib
import copy
import random
from typing import Any, Dict, Iterable, Optional, Union
from typing import Any, Dict, Iterable, List, Optional, Union
import numpy as np
import torch
@@ -121,6 +121,16 @@ def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
return lora_state_dict
def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32):
if not isinstance(model, list):
model = [model]
for m in model:
for param in m.parameters():
# only upcast trainable parameters into fp32
if param.requires_grad:
param.data = param.to(dtype)
def _set_state_dict_into_text_encoder(
lora_state_dict: Dict[str, torch.Tensor], prefix: str, text_encoder: torch.nn.Module
):