[Training Utils] create a utility for casting the lora params during training. (#6553)
create a utility for casting the lora params during training.
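
Background: in fp16 mixed-precision training the frozen base weights run in half precision, while the trainable LoRA parameters must be kept in float32 so that optimizer updates do not underflow. Each LoRA training script previously open-coded the same upcasting loop; this commit factors it into a single cast_training_params helper in diffusers.training_utils. A minimal sketch of the pattern being replaced (unet stands in for any LoRA-adapted module):

# before: each script repeated this loop inline
for param in unet.parameters():
    # only the LoRA adapter weights are trainable
    if param.requires_grad:
        param.data = param.to(torch.float32)

# after: one call into the shared utility
cast_training_params(unet, dtype=torch.float32)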
@@ -51,7 +51,7 @@ from diffusers import (
     UNet2DConditionModel,
 )
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import resolve_interpolation_mode
+from diffusers.training_utils import cast_training_params, resolve_interpolation_mode
 from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available

@@ -860,10 +860,8 @@ def main(args):
     # Make sure the trainable params are in float32.
     if args.mixed_precision == "fp16":
-        for param in unet.parameters():
-            # only upcast trainable parameters (LoRA) into fp32
-            if param.requires_grad:
-                param.data = param.to(torch.float32)
+        # only upcast trainable parameters (LoRA) into fp32
+        cast_training_params(unet, dtype=torch.float32)
 
     # Also move the alpha and sigma noise schedules to accelerator.device.
     alpha_schedule = alpha_schedule.to(accelerator.device)
     sigma_schedule = sigma_schedule.to(accelerator.device)

@@ -53,7 +53,7 @@ from diffusers import (
 )
 from diffusers.loaders import LoraLoaderMixin
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import _set_state_dict_into_text_encoder, compute_snr
+from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
 from diffusers.utils import (
     check_min_version,
     convert_state_dict_to_diffusers,

@@ -1086,11 +1086,8 @@ def main(args):
             models = [unet_]
             if args.train_text_encoder:
                 models.extend([text_encoder_one_, text_encoder_two_])
-            for model in models:
-                for param in model.parameters():
-                    # only upcast trainable parameters (LoRA) into fp32
-                    if param.requires_grad:
-                        param.data = param.to(torch.float32)
+            # only upcast trainable parameters (LoRA) into fp32
+            cast_training_params(models)
 
     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)

@@ -1110,11 +1107,9 @@ def main(args):
         models = [unet]
         if args.train_text_encoder:
             models.extend([text_encoder_one, text_encoder_two])
-        for model in models:
-            for param in model.parameters():
-                # only upcast trainable parameters (LoRA) into fp32
-                if param.requires_grad:
-                    param.data = param.to(torch.float32)
+
+        # only upcast trainable parameters (LoRA) into fp32
+        cast_training_params(models, dtype=torch.float32)
 
     unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
 

@@ -43,7 +43,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
 import diffusers
 from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import compute_snr
+from diffusers.training_utils import cast_training_params, compute_snr
 from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available

@@ -466,10 +466,8 @@ def main():
     # Add adapter and make sure the trainable params are in float32.
     unet.add_adapter(unet_lora_config)
     if args.mixed_precision == "fp16":
-        for param in unet.parameters():
-            # only upcast trainable parameters (LoRA) into fp32
-            if param.requires_grad:
-                param.data = param.to(torch.float32)
+        # only upcast trainable parameters (LoRA) into fp32
+        cast_training_params(unet, dtype=torch.float32)
 
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():

@@ -51,7 +51,7 @@ from diffusers import (
 )
 from diffusers.loaders import LoraLoaderMixin
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import compute_snr
+from diffusers.training_utils import cast_training_params, compute_snr
 from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available

@@ -634,11 +634,8 @@ def main(args):
         models = [unet]
         if args.train_text_encoder:
             models.extend([text_encoder_one, text_encoder_two])
-        for model in models:
-            for param in model.parameters():
-                # only upcast trainable parameters (LoRA) into fp32
-                if param.requires_grad:
-                    param.data = param.to(torch.float32)
+        # only upcast trainable parameters (LoRA) into fp32
+        cast_training_params(models, dtype=torch.float32)
 
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):

@@ -1,7 +1,7 @@
 import contextlib
 import copy
 import random
-from typing import Any, Dict, Iterable, Optional, Union
+from typing import Any, Dict, Iterable, List, Optional, Union
 
 import numpy as np
 import torch

@@ -121,6 +121,16 @@ def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
     return lora_state_dict
 
 
+def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32):
+    if not isinstance(model, list):
+        model = [model]
+    for m in model:
+        for param in m.parameters():
+            # only upcast trainable parameters into fp32
+            if param.requires_grad:
+                param.data = param.to(dtype)
+
+
 def _set_state_dict_into_text_encoder(
     lora_state_dict: Dict[str, torch.Tensor], prefix: str, text_encoder: torch.nn.Module
 ):
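
For reference, a self-contained sketch of how the new helper behaves, assuming a diffusers version that includes this commit; the single Linear layer below is a hypothetical stand-in for a LoRA-adapted model, not code from the commit:

import torch

from diffusers.training_utils import cast_training_params

# Stand-in for a LoRA-adapted model: the frozen base weight stays in
# fp16, while the trainable parameter is what the optimizer updates.
layer = torch.nn.Linear(4, 4).to(torch.float16)
layer.weight.requires_grad_(False)  # frozen base weight
layer.bias.requires_grad_(True)     # plays the role of a LoRA param

cast_training_params(layer, dtype=torch.float32)  # dtype defaults to float32

assert layer.weight.dtype == torch.float16  # frozen weight is left untouched
assert layer.bias.dtype == torch.float32    # trainable param upcast for stable updates

# The helper also accepts a list of modules, matching the call sites above:
# cast_training_params([unet, text_encoder_one], dtype=torch.float32)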