mirror of https://github.com/huggingface/diffusers.git
[Training Utils] create a utility for casting the lora params during training. (#6553)
Adds cast_training_params, a small helper that upcasts only the trainable parameters (e.g. LoRA weights) of one or more models to a target dtype (fp32 by default), so mixed-precision training scripts no longer have to repeat this loop inline.
@@ -1,7 +1,7 @@
 import contextlib
 import copy
 import random
-from typing import Any, Dict, Iterable, Optional, Union
+from typing import Any, Dict, Iterable, List, Optional, Union
 
 import numpy as np
 import torch
@@ -121,6 +121,16 @@ def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
     return lora_state_dict
 
 
+def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32):
+    if not isinstance(model, list):
+        model = [model]
+    for m in model:
+        for param in m.parameters():
+            # only upcast trainable parameters into fp32
+            if param.requires_grad:
+                param.data = param.to(dtype)
+
+
 def _set_state_dict_into_text_encoder(
     lora_state_dict: Dict[str, torch.Tensor], prefix: str, text_encoder: torch.nn.Module
 ):
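For context, here is a minimal usage sketch. It assumes the helper is exposed from diffusers.training_utils (the module the [Training Utils] tag and the surrounding diff point to) and uses toy Linear modules as stand-ins for a frozen fp16 base model plus a trainable adapter; none of this setup code is part of the commit itself.

import torch

# Assumption: the new helper is importable from diffusers.training_utils,
# as the patched module in this diff suggests.
from diffusers.training_utils import cast_training_params

# Toy stand-ins for a half-precision base model and a small trainable adapter.
base = torch.nn.Linear(8, 8).to(torch.float16)
adapter = torch.nn.Linear(8, 8).to(torch.float16)

# Freeze the base; only the adapter's parameters require gradients.
base.requires_grad_(False)
adapter.requires_grad_(True)

# The helper accepts a single module or a list and upcasts only the
# trainable parameters; the frozen fp16 base weights are left untouched.
cast_training_params([base, adapter], dtype=torch.float32)

assert all(p.dtype == torch.float16 for p in base.parameters())
assert all(p.dtype == torch.float32 for p in adapter.parameters())

Keeping the optimizer-updated parameters in fp32 while the frozen base stays in fp16 is the usual mixed-precision LoRA recipe; gating the cast on requires_grad is what lets a single call handle both kinds of parameters.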