import copy
import os
import random
from typing import Any, Dict, Iterable, Optional, Union

import numpy as np
import torch

from .utils import deprecate


def enable_full_determinism(seed: int):
    """
    Helper function for reproducible behavior during distributed training. See
    - https://pytorch.org/docs/stable/notes/randomness.html for pytorch
    """
    # set seed first
    set_seed(seed)

    # Enable PyTorch deterministic mode. This potentially requires either the environment
    # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
    # depending on the CUDA version, so we set them both here
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
    torch.use_deterministic_algorithms(True)

    # Enable CUDNN deterministic mode
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
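
# Illustrative usage (a sketch, not part of the original module): the environment variables
# above are only read when CUDA kernels launch, so call this helper once at program start,
# before any models or CUDA tensors are created:
#
#     enable_full_determinism(42)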


def set_seed(seed: int):
    """
    Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.

    Args:
        seed (`int`): The seed to set.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # ^^ safe to call this function even if cuda is not available


# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
class EMAModel:
    """
    Exponential Moving Average of model weights.
    """

    def __init__(
        self,
        parameters: Iterable[torch.nn.Parameter],
        decay: float = 0.9999,
        min_decay: float = 0.0,
        update_after_step: int = 0,
        use_ema_warmup: bool = False,
        inv_gamma: Union[float, int] = 1.0,
        power: Union[float, int] = 2 / 3,
        model_cls: Optional[Any] = None,
        model_config: Dict[str, Any] = None,
        **kwargs,
    ):
        """
        Args:
            parameters (`Iterable[torch.nn.Parameter]`): The parameters to track.
            decay (`float`): The maximum decay factor for the exponential moving average.
            min_decay (`float`): The minimum decay factor for the exponential moving average.
            update_after_step (`int`): The number of steps to wait before starting to update the EMA weights.
            use_ema_warmup (`bool`): Whether to use EMA warmup.
            inv_gamma (`float`):
                Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True.
            power (`float`): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True.
            device (`Optional[Union[str, torch.device]]`):
                Deprecated, accepted via `kwargs`. The device to store the EMA weights on. If None, the EMA
                weights will be stored on CPU.

        @crowsonkb's notes on EMA Warmup:
            If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you
            plan to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M
            steps), gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K
            steps, 0.9999 at 215.4k steps).
        """

        if isinstance(parameters, torch.nn.Module):
            deprecation_message = (
                "Passing a `torch.nn.Module` to `ExponentialMovingAverage` is deprecated. "
                "Please pass the parameters of the module instead."
            )
            deprecate(
                "passing a `torch.nn.Module` to `ExponentialMovingAverage`",
                "1.0.0",
                deprecation_message,
                standard_warn=False,
            )
            parameters = parameters.parameters()

            # set use_ema_warmup to True if a torch.nn.Module is passed for backwards compatibility
            use_ema_warmup = True

        if kwargs.get("max_value", None) is not None:
            deprecation_message = "The `max_value` argument is deprecated. Please use `decay` instead."
            deprecate("max_value", "1.0.0", deprecation_message, standard_warn=False)
            decay = kwargs["max_value"]

        if kwargs.get("min_value", None) is not None:
            deprecation_message = "The `min_value` argument is deprecated. Please use `min_decay` instead."
            deprecate("min_value", "1.0.0", deprecation_message, standard_warn=False)
            min_decay = kwargs["min_value"]

        parameters = list(parameters)
        self.shadow_params = [p.clone().detach() for p in parameters]

        if kwargs.get("device", None) is not None:
            deprecation_message = "The `device` argument is deprecated. Please use `to` instead."
            deprecate("device", "1.0.0", deprecation_message, standard_warn=False)
            self.to(device=kwargs["device"])

        self.temp_stored_params = None

        self.decay = decay
        self.min_decay = min_decay
        self.update_after_step = update_after_step
        self.use_ema_warmup = use_ema_warmup
        self.inv_gamma = inv_gamma
        self.power = power
        self.optimization_step = 0
        self.cur_decay_value = None  # set in `step()`

        self.model_cls = model_cls
        self.model_config = model_config

    @classmethod
    def from_pretrained(cls, path, model_cls) -> "EMAModel":
        _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True)
        model = model_cls.from_pretrained(path)

        ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config)

        ema_model.load_state_dict(ema_kwargs)
        return ema_model

    def save_pretrained(self, path):
        if self.model_cls is None:
            raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.")

        if self.model_config is None:
            raise ValueError("`save_pretrained` can only be used if `model_config` was defined at __init__.")

        model = self.model_cls.from_config(self.model_config)
        state_dict = self.state_dict()
        # the EMA weights are copied onto the model below; only the scalar hyperparameters go into the config
        state_dict.pop("shadow_params", None)

        model.register_to_config(**state_dict)
        self.copy_to(model.parameters())
        model.save_pretrained(path)
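
    # Illustrative round trip (a sketch, not part of the original module), assuming `unet`
    # is a diffusers model such as `UNet2DConditionModel`:
    #
    #     ema_unet = EMAModel(unet.parameters(), model_cls=UNet2DConditionModel, model_config=unet.config)
    #     ema_unet.save_pretrained("unet_ema")
    #     ema_unet = EMAModel.from_pretrained("unet_ema", model_cls=UNet2DConditionModel)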

    def get_decay(self, optimization_step: int) -> float:
        """
        Compute the decay factor for the exponential moving average.
        """
        step = max(0, optimization_step - self.update_after_step - 1)

        if step <= 0:
            return 0.0

        if self.use_ema_warmup:
            cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power
        else:
            cur_decay_value = (1 + step) / (10 + step)

        cur_decay_value = min(cur_decay_value, self.decay)
        # make sure decay is not smaller than min_decay
        cur_decay_value = max(cur_decay_value, self.min_decay)
        return cur_decay_value
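
    # Worked check of @crowsonkb's warmup numbers (illustrative, not part of the original
    # module): with inv_gamma=1 and power=2/3 the warmup branch gives
    #
    #     decay(step) = 1 - (1 + step) ** (-2 / 3)
    #
    # so at step = 31.6K ≈ 10**4.5, (10**4.5) ** (2/3) = 10**3 and decay ≈ 1 - 1/1000 = 0.999;
    # at step = 1M = 10**6, (10**6) ** (2/3) = 10**4 and decay ≈ 0.9999, matching the figures
    # quoted in the `__init__` docstring.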

    @torch.no_grad()
    def step(self, parameters: Iterable[torch.nn.Parameter]):
        if isinstance(parameters, torch.nn.Module):
            deprecation_message = (
                "Passing a `torch.nn.Module` to `ExponentialMovingAverage.step` is deprecated. "
                "Please pass the parameters of the module instead."
            )
            deprecate(
                "passing a `torch.nn.Module` to `ExponentialMovingAverage.step`",
                "1.0.0",
                deprecation_message,
                standard_warn=False,
            )
            parameters = parameters.parameters()

        parameters = list(parameters)

        self.optimization_step += 1

        # Compute the decay factor for the exponential moving average.
        decay = self.get_decay(self.optimization_step)
        self.cur_decay_value = decay
        one_minus_decay = 1 - decay

        for s_param, param in zip(self.shadow_params, parameters):
            if param.requires_grad:
                s_param.sub_(one_minus_decay * (s_param - param))
            else:
                s_param.copy_(param)
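
    # Illustrative training-loop usage (a sketch, not part of the original module); `model`,
    # `optimizer`, `dataloader`, and `compute_loss` are hypothetical stand-ins:
    #
    #     ema = EMAModel(model.parameters(), decay=0.9999, use_ema_warmup=True)
    #     for batch in dataloader:
    #         compute_loss(model, batch).backward()
    #         optimizer.step()
    #         optimizer.zero_grad()
    #         ema.step(model.parameters())  # EMA update follows each optimizer update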

    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        """
        Copy current averaged parameters into given collection of parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored moving
                averages.
        """
        parameters = list(parameters)
        for s_param, param in zip(self.shadow_params, parameters):
            param.data.copy_(s_param.to(param.device).data)

    def to(self, device=None, dtype=None) -> None:
        r"""Move internal buffers of the ExponentialMovingAverage to `device`.

        Args:
            device: like `device` argument to `torch.Tensor.to`
            dtype: like `dtype` argument to `torch.Tensor.to`; only applied to floating-point buffers
        """
        # .to() on the tensors handles None correctly
        self.shadow_params = [
            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
            for p in self.shadow_params
        ]

    def state_dict(self) -> dict:
        r"""
        Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during
        checkpointing to save the ema state dict.
        """
        # Following PyTorch conventions, references to tensors are returned:
        # "returns a reference to the state and not its copy!" -
        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
        return {
            "decay": self.decay,
            "min_decay": self.min_decay,
            "optimization_step": self.optimization_step,
            "update_after_step": self.update_after_step,
            "use_ema_warmup": self.use_ema_warmup,
            "inv_gamma": self.inv_gamma,
            "power": self.power,
            "shadow_params": self.shadow_params,
        }

    def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        r"""
        Save the current parameters for restoring later.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be temporarily stored.
        """
        self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]

    def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        r"""
        Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters
        without affecting the original optimization process. Store the parameters before the `copy_to()` method.
        After validation (or model saving), use this to restore the former parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored parameters.
        """
        if self.temp_stored_params is None:
            raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`")
        for c_param, param in zip(self.temp_stored_params, parameters):
            param.data.copy_(c_param.data)

        # Better memory-wise.
        self.temp_stored_params = None
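
    # Illustrative evaluation swap (a sketch, not part of the original module); `model` and
    # `validate` are hypothetical stand-ins:
    #
    #     ema.store(model.parameters())    # stash the live weights
    #     ema.copy_to(model.parameters())  # evaluate with the EMA weights
    #     validate(model)
    #     ema.restore(model.parameters())  # put the live weights back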

    def load_state_dict(self, state_dict: dict) -> None:
        r"""
        Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to
        load the ema state dict.

        Args:
            state_dict (dict): EMA state. Should be an object returned from a call to :meth:`state_dict`.
        """
        # deepcopy, to be consistent with module API
        state_dict = copy.deepcopy(state_dict)

        self.decay = state_dict.get("decay", self.decay)
        if self.decay < 0.0 or self.decay > 1.0:
            raise ValueError("Decay must be between 0 and 1")

        self.min_decay = state_dict.get("min_decay", self.min_decay)
        if not isinstance(self.min_decay, float):
            raise ValueError("Invalid min_decay")

        self.optimization_step = state_dict.get("optimization_step", self.optimization_step)
        if not isinstance(self.optimization_step, int):
            raise ValueError("Invalid optimization_step")

        self.update_after_step = state_dict.get("update_after_step", self.update_after_step)
        if not isinstance(self.update_after_step, int):
            raise ValueError("Invalid update_after_step")

        self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup)
        if not isinstance(self.use_ema_warmup, bool):
            raise ValueError("Invalid use_ema_warmup")

        self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma)
        if not isinstance(self.inv_gamma, (float, int)):
            raise ValueError("Invalid inv_gamma")

        self.power = state_dict.get("power", self.power)
        if not isinstance(self.power, (float, int)):
            raise ValueError("Invalid power")

        shadow_params = state_dict.get("shadow_params", None)
        if shadow_params is not None:
            self.shadow_params = shadow_params
            if not isinstance(self.shadow_params, list):
                raise ValueError("shadow_params must be a list")
            if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
                raise ValueError("shadow_params must all be Tensors")
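
# Illustrative checkpoint integration (a sketch, not part of the original module): because
# `EMAModel` implements `state_dict`/`load_state_dict`, an accelerate `Accelerator` can
# checkpoint it alongside the model and optimizer; `accelerator` is a hypothetical instance:
#
#     accelerator.register_for_checkpointing(ema)
#     accelerator.save_state("checkpoint-1000")  # includes the EMA state
#     accelerator.load_state("checkpoint-1000")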