Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
Use accelerate save & loading hooks to have better checkpoint structure (#2048)
* better accelerated saving
* up
* finish
* finish
* uP
* up
* up
* fix
* Apply suggestions from code review
* correct ema
* Remove @
* up
* Apply suggestions from code review
  Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* Update docs/source/en/training/dreambooth.mdx
  Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
---------
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Committed by: GitHub
Parent: e619db24be
Commit: f5ccffecf7
@@ -127,7 +127,30 @@ This would be a good opportunity to tweak some of your hyperparameters if you wish.
Saved checkpoints are stored in a format suitable for resuming training. They not only include the model weights, but also the state of the optimizer, data loaders and learning rate.
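As a concrete illustration, the sketch below shows the underlying `accelerate` calls the training script builds on; it is a minimal, self-contained toy (tiny linear model, placeholder checkpoint name), not the training script itself.

```python
import torch
from accelerate import Accelerator
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-ins for the UNet, optimizer, LR scheduler and dataloader used in training.
accelerator = Accelerator()
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
lr_scheduler = LambdaLR(optimizer, lambda step: 1.0)
dataloader = DataLoader(TensorDataset(torch.randn(16, 8)), batch_size=4)

model, optimizer, dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, dataloader, lr_scheduler
)

# Writes the model weights plus optimizer, LR-scheduler and RNG state into the folder.
accelerator.save_state("checkpoint-100")

# ... later, restore everything and continue training from that exact point.
accelerator.load_state("checkpoint-100")
```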
You can use a checkpoint for inference, but first you need to convert it to an inference pipeline. This is how you could do it:
**Note**: If you have installed `"accelerate>=0.16.0"`, you can use the following code to run inference from an intermediate checkpoint.
```python
from diffusers import DiffusionPipeline, UNet2DConditionModel
from transformers import CLIPTextModel
import torch

# Load the pipeline with the same arguments (model, revision) that were used for training
model_id = "CompVis/stable-diffusion-v1-4"

unet = UNet2DConditionModel.from_pretrained("/sddata/dreambooth/daruma-v2-1/checkpoint-100/unet")

# if you have trained with `--train_text_encoder`, make sure to also load the text encoder
text_encoder = CLIPTextModel.from_pretrained("/sddata/dreambooth/daruma-v2-1/checkpoint-100/text_encoder")

pipeline = DiffusionPipeline.from_pretrained(model_id, unet=unet, text_encoder=text_encoder, torch_dtype=torch.float16)
pipeline.to("cuda")

# Perform inference, or save, or push to the hub
pipeline.save_pretrained("dreambooth-pipeline")
```
If you have installed `"accelerate<0.16.0"`, you need to first convert the checkpoint into an inference pipeline. This is how you could do it:
```python
from accelerate import Accelerator
```
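The diff hunk ends before the rest of this snippet. As a rough sketch of the general `accelerate<0.16.0` approach (the checkpoint path reuses the example above; the exact pipeline rebuild shown here is illustrative rather than the documented snippet): restore the training state into prepared copies of the sub-models with `Accelerator.load_state`, then build an inference pipeline from the unwrapped models.

```python
from accelerate import Accelerator
from diffusers import DiffusionPipeline

# Load the pipeline with the same arguments (model, revision) that were used for training.
model_id = "CompVis/stable-diffusion-v1-4"
pipeline = DiffusionPipeline.from_pretrained(model_id)

accelerator = Accelerator()

# Wrap the sub-models the same way they were wrapped during training so that
# the serialized training state can be loaded back into them.
unet, text_encoder = accelerator.prepare(pipeline.unet, pipeline.text_encoder)
accelerator.load_state("/sddata/dreambooth/daruma-v2-1/checkpoint-100")

# Rebuild an inference pipeline around the restored, unwrapped models.
pipeline = DiffusionPipeline.from_pretrained(
    model_id,
    unet=accelerator.unwrap_model(unet),
    text_encoder=accelerator.unwrap_model(text_encoder),
)
pipeline.save_pretrained("dreambooth-pipeline")
```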
@@ -271,6 +294,10 @@ accelerate launch train_dreambooth.py \
Once you have trained a model, you can run inference with the `StableDiffusionPipeline` simply by indicating the path where the model was saved. Make sure that your prompts include the special `identifier` used during training (`sks` in the previous examples).
**Note**: If you have installed `"accelerate>=0.16.0"`, you can use the following code to run inference from an intermediate checkpoint.
```python
from diffusers import StableDiffusionPipeline
import torch
@@ -284,4 +311,4 @@ image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("dog-bucket.png")
```
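The snippet above is split across two diff hunks; a self-contained sketch of the same inference flow (the model path, prompt, and fp16 choice are illustrative):

```python
from diffusers import StableDiffusionPipeline
import torch

# Path passed as --output_dir during training (placeholder here).
model_path = "path-to-your-trained-model"
pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
pipe.to("cuda")

# The prompt must contain the identifier used during training (`sks` here).
prompt = "A photo of sks dog in a bucket"
image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]

image.save("dog-bucket.png")
```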
You may also run inference from [any of the saved training checkpoints](#performing-inference-using-a-saved-checkpoint).
@@ -28,6 +28,7 @@ import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import Dataset

import accelerate
import diffusers
import transformers
from accelerate import Accelerator
@@ -38,6 +39,7 @@ from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from packaging import version
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
@@ -606,6 +608,37 @@ def main(args):
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
    )

    # `accelerate` 0.16.0 will have better support for customized saving
    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
        def save_model_hook(models, weights, output_dir):
            for model in models:
                sub_dir = "unet" if type(model) == type(unet) else "text_encoder"
                model.save_pretrained(os.path.join(output_dir, sub_dir))

                # make sure to pop weight so that corresponding model is not saved again
                weights.pop()

        def load_model_hook(models, input_dir):
            while len(models) > 0:
                # pop models so that they are not loaded again
                model = models.pop()

                if type(model) == type(text_encoder):
                    # load transformers style into model
                    load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder")
                    model.config = load_model.config
                else:
                    # load diffusers style into model
                    load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
                    model.register_to_config(**load_model.config)

                model.load_state_dict(load_model.state_dict())
                del load_model

        accelerator.register_save_state_pre_hook(save_model_hook)
        accelerator.register_load_state_pre_hook(load_model_hook)
    vae.requires_grad_(False)
    if not args.train_text_encoder:
        text_encoder.requires_grad_(False)
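With the hooks above registered, each periodic `accelerator.save_state(...)` call during training writes the UNet (and, when `--train_text_encoder` is used, the text encoder) into its own subfolder next to `accelerate`'s optimizer and RNG state. The toy below sketches the same hook mechanism with a generic `torch.nn` model and made-up file names; it illustrates the API rather than the training script.

```python
import os
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model, optimizer = accelerator.prepare(model, optimizer)

def save_model_hook(models, weights, output_dir):
    for m in models:
        # Serialize the model in a custom layout ...
        torch.save(m.state_dict(), os.path.join(output_dir, "my_model.bin"))
        # ... and pop its weights so accelerate does not save them a second time.
        weights.pop()

def load_model_hook(models, input_dir):
    while len(models) > 0:
        # Pop so accelerate does not try to load the model again itself.
        m = models.pop()
        m.load_state_dict(torch.load(os.path.join(input_dir, "my_model.bin")))

accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

accelerator.save_state("checkpoint-toy")  # model via the hook, optimizer/RNG via accelerate
accelerator.load_state("checkpoint-toy")  # model restored through the load hook
```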
@@ -26,6 +26,7 @@ import torch
import torch.nn.functional as F
import torch.utils.checkpoint

import accelerate
import datasets
import diffusers
import transformers
@@ -36,9 +37,10 @@ from datasets import load_dataset
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version
from diffusers.utils import check_min_version, deprecate
from diffusers.utils.import_utils import is_xformers_available
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from packaging import version
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
@@ -319,6 +321,16 @@ dataset_name_mapping = {
def main():
    args = parse_args()

    if args.non_ema_revision is not None:
        deprecate(
            "non_ema_revision!=None",
            "0.15.0",
            message=(
                "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
                " use `--variant=non_ema` instead."
            ),
        )
    logging_dir = os.path.join(args.output_dir, args.logging_dir)

    accelerator = Accelerator(
@@ -396,6 +408,39 @@ def main():
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    # `accelerate` 0.16.0 will have better support for customized saving
    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
        def save_model_hook(models, weights, output_dir):
            if args.use_ema:
                ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))

            for i, model in enumerate(models):
                model.save_pretrained(os.path.join(output_dir, "unet"))

                # make sure to pop weight so that corresponding model is not saved again
                weights.pop()

        def load_model_hook(models, input_dir):
            if args.use_ema:
                load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
                ema_unet.load_state_dict(load_model.state_dict())
                del load_model

            for i in range(len(models)):
                # pop models so that they are not loaded again
                model = models.pop()

                # load diffusers style into model
                load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
                model.register_to_config(**load_model.config)

                model.load_state_dict(load_model.state_dict())
                del load_model

        accelerator.register_save_state_pre_hook(save_model_hook)
        accelerator.register_load_state_pre_hook(load_model_hook)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
@@ -552,8 +597,9 @@ def main():
    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        unet, optimizer, train_dataloader, lr_scheduler
    )

    if args.use_ema:
        accelerator.register_for_checkpointing(ema_unet)
        ema_unet.to(accelerator.device)

    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
@@ -566,8 +612,6 @@ def main():
    # Move text_encoder and vae to gpu and cast to weight_dtype
    text_encoder.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)
    if args.use_ema:
        ema_unet.to(accelerator.device)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -9,6 +9,7 @@ from typing import Optional
import torch
import torch.nn.functional as F

import accelerate
import datasets
import diffusers
from accelerate import Accelerator
@@ -19,6 +20,7 @@ from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from packaging import version
from torchvision import transforms
from tqdm.auto import tqdm
@@ -271,6 +273,40 @@ def main(args):
        logging_dir=logging_dir,
    )

    # `accelerate` 0.16.0 will have better support for customized saving
    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
        def save_model_hook(models, weights, output_dir):
            if args.use_ema:
                ema_model.save_pretrained(os.path.join(output_dir, "unet_ema"))

            for i, model in enumerate(models):
                model.save_pretrained(os.path.join(output_dir, "unet"))

                # make sure to pop weight so that corresponding model is not saved again
                weights.pop()

        def load_model_hook(models, input_dir):
            if args.use_ema:
                load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DModel)
                ema_model.load_state_dict(load_model.state_dict())
                ema_model.to(accelerator.device)
                del load_model

            for i in range(len(models)):
                # pop models so that they are not loaded again
                model = models.pop()

                # load diffusers style into model
                load_model = UNet2DModel.from_pretrained(input_dir, subfolder="unet")
                model.register_to_config(**load_model.config)

                model.load_state_dict(load_model.state_dict())
                del load_model

        accelerator.register_save_state_pre_hook(save_model_hook)
        accelerator.register_load_state_pre_hook(load_model_hook)

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -336,6 +372,8 @@ def main(args):
        use_ema_warmup=True,
        inv_gamma=args.ema_inv_gamma,
        power=args.ema_power,
        model_cls=UNet2DModel,
        model_config=model.config,
    )

    # Initialize the scheduler
@@ -411,7 +449,6 @@ def main(args):
    )

    if args.use_ema:
        accelerator.register_for_checkpointing(ema_model)
        ema_model.to(accelerator.device)

    # We need to initialize the trackers we use, and also store our configuration.
@@ -1,7 +1,7 @@
import copy
import os
import random
from typing import Iterable, Union
from typing import Any, Dict, Iterable, Optional, Union

import numpy as np
import torch
@@ -57,6 +57,8 @@ class EMAModel:
        use_ema_warmup: bool = False,
        inv_gamma: Union[float, int] = 1.0,
        power: Union[float, int] = 2 / 3,
        model_cls: Optional[Any] = None,
        model_config: Dict[str, Any] = None,
        **kwargs,
    ):
        """
@@ -123,6 +125,35 @@ class EMAModel:
        self.power = power
        self.optimization_step = 0

        self.model_cls = model_cls
        self.model_config = model_config

    @classmethod
    def from_pretrained(cls, path, model_cls) -> "EMAModel":
        _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True)
        model = model_cls.from_pretrained(path)

        ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config)

        ema_model.load_state_dict(ema_kwargs)
        return ema_model

    def save_pretrained(self, path):
        if self.model_cls is None:
            raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.")

        if self.model_config is None:
            raise ValueError("`save_pretrained` can only be used if `model_config` was defined at __init__.")

        model = self.model_cls.from_config(self.model_config)
        state_dict = self.state_dict()
        state_dict.pop("shadow_params", None)
        state_dict.pop("collected_params", None)

        model.register_to_config(**state_dict)
        self.copy_to(model.parameters())
        model.save_pretrained(path)

    def get_decay(self, optimization_step: int) -> float:
        """
        Compute the decay factor for the exponential moving average.
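These two methods are what the training scripts' new hooks rely on to serialize EMA weights as a regular `diffusers` model (for example into a `unet_ema` subfolder). A small usage sketch; the default-config `UNet2DModel` and the directory names are placeholders:

```python
import os
from diffusers import UNet2DModel
from diffusers.training_utils import EMAModel

# Placeholder model; in the training scripts this is the UNet being trained.
unet = UNet2DModel()
ema_unet = EMAModel(unet.parameters(), model_cls=UNet2DModel, model_config=unet.config)

# Write the EMA weights (plus decay/warmup settings) in the usual diffusers format.
ema_unet.save_pretrained(os.path.join("checkpoint-100", "unet_ema"))

# Restore them later, e.g. from inside a `load_state` hook.
loaded = EMAModel.from_pretrained(os.path.join("checkpoint-100", "unet_ema"), UNet2DModel)
ema_unet.load_state_dict(loaded.state_dict())
```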
@@ -184,7 +215,7 @@ class EMAModel:
        """
        parameters = list(parameters)
        for s_param, param in zip(self.shadow_params, parameters):
            param.data.copy_(s_param.data)
            param.data.copy_(s_param.to(param.device).data)

    def to(self, device=None, dtype=None) -> None:
        r"""Move internal buffers of the ExponentialMovingAverage to `device`.
@@ -257,13 +288,15 @@ class EMAModel:
        if not isinstance(self.power, (float, int)):
            raise ValueError("Invalid power")

        self.shadow_params = state_dict["shadow_params"]
        if not isinstance(self.shadow_params, list):
            raise ValueError("shadow_params must be a list")
        if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
            raise ValueError("shadow_params must all be Tensors")
        shadow_params = state_dict.get("shadow_params", None)
        if shadow_params is not None:
            self.shadow_params = shadow_params
            if not isinstance(self.shadow_params, list):
                raise ValueError("shadow_params must be a list")
            if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
                raise ValueError("shadow_params must all be Tensors")

        self.collected_params = state_dict["collected_params"]
        self.collected_params = state_dict.get("collected_params", None)
        if self.collected_params is not None:
            if not isinstance(self.collected_params, list):
                raise ValueError("collected_params must be a list")