
Allow saving trained betas (#1468)

Patrick von Platen
2022-11-30 10:05:51 +01:00
committed by GitHub
parent 0b7225e918
commit 110ffe2589
11 changed files with 55 additions and 30 deletions

src/diffusers/configuration_utils.py

@@ -24,6 +24,8 @@ import re
 from collections import OrderedDict
 from typing import Any, Dict, Tuple, Union
 
+import numpy as np
+
 from huggingface_hub import hf_hub_download
 from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
 from requests import HTTPError
@@ -502,6 +504,12 @@ class ConfigMixin:
         config_dict["_class_name"] = self.__class__.__name__
         config_dict["_diffusers_version"] = __version__
 
+        def to_json_saveable(value):
+            if isinstance(value, np.ndarray):
+                value = value.tolist()
+            return value
+
+        config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
         return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
 
     def to_json_file(self, json_file_path: Union[str, os.PathLike]):
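
Aside (not part of the diff): the new `to_json_saveable` hook exists because `json.dumps` cannot serialize numpy arrays, so a `trained_betas` ndarray in the config would make `to_json_string` raise. A minimal standalone sketch of the failure and the fix:

import json

import numpy as np

config_dict = {"beta_schedule": "linear", "trained_betas": np.array([0.0, 0.1])}

try:
    json.dumps(config_dict)
except TypeError as err:
    print(err)  # Object of type ndarray is not JSON serializable

# Same conversion as the hook above: ndarrays become plain Python lists.
def to_json_saveable(value):
    if isinstance(value, np.ndarray):
        value = value.tolist()
    return value

config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
print(json.dumps(config_dict, indent=2, sort_keys=True))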

src/diffusers/schedulers/scheduling_ddim.py

@@ -17,7 +17,7 @@
 import math
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -123,7 +123,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         clip_sample: bool = True,
         set_alpha_to_one: bool = True,
         steps_offset: int = 0,
@@ -139,7 +139,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
 
         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":
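
The `torch.from_numpy` → `torch.tensor` swap, repeated in each scheduler below, is what makes a reloaded config usable: after the JSON round-trip, `trained_betas` comes back as a plain list, which `torch.from_numpy` rejects. A quick illustrative sketch, not taken from the commit:

import numpy as np
import torch

betas_array = np.array([0.0, 0.1])
betas_list = [0.0, 0.1]  # what the saved JSON config yields on reload

torch.from_numpy(betas_array)  # fine
try:
    torch.from_numpy(betas_list)  # the old code path
except TypeError as err:
    print(err)  # expected np.ndarray (got list)

# torch.tensor accepts both, and the explicit dtype avoids inheriting
# float64 from numpy's default dtype.
print(torch.tensor(betas_array, dtype=torch.float32))
print(torch.tensor(betas_list, dtype=torch.float32))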

src/diffusers/schedulers/scheduling_ddpm.py

@@ -16,7 +16,7 @@
 import math
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -115,7 +115,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         variance_type: str = "fixed_small",
         clip_sample: bool = True,
         prediction_type: str = "epsilon",
@@ -130,7 +130,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
 
         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":

src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

@@ -127,7 +127,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         solver_order: int = 2,
         prediction_type: str = "epsilon",
         thresholding: bool = False,
@@ -147,7 +147,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
 
         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":

src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -77,10 +77,10 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
     ):
         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":

src/diffusers/schedulers/scheduling_euler_discrete.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -78,11 +78,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         prediction_type: str = "epsilon",
     ):
         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":

src/diffusers/schedulers/scheduling_heun_discrete.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -53,10 +53,10 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         beta_start: float = 0.00085,  # sensible defaults
         beta_end: float = 0.012,
         beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
     ):
         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":

src/diffusers/schedulers/scheduling_ipndm.py

@@ -13,8 +13,9 @@
 # limitations under the License.
 
 import math
-from typing import Tuple, Union
+from typing import List, Optional, Tuple, Union
 
+import numpy as np
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
@@ -40,7 +41,9 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
     order = 1
 
     @register_to_config
-    def __init__(self, num_train_timesteps: int = 1000):
+    def __init__(
+        self, num_train_timesteps: int = 1000, trained_betas: Optional[Union[np.ndarray, List[float]]] = None
+    ):
         # set `betas`, `alphas`, `timesteps`
         self.set_timesteps(num_train_timesteps)
@@ -67,7 +70,11 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
         steps = torch.linspace(1, 0, num_inference_steps + 1)[:-1]
         steps = torch.cat([steps, torch.tensor([0.0])])
 
-        self.betas = torch.sin(steps * math.pi / 2) ** 2
+        if self.config.trained_betas is not None:
+            self.betas = torch.tensor(self.config.trained_betas, dtype=torch.float32)
+        else:
+            self.betas = torch.sin(steps * math.pi / 2) ** 2
+
         self.alphas = (1.0 - self.betas**2) ** 0.5
 
         timesteps = (torch.atan2(self.betas, self.alphas) / math.pi * 2)[:-1]
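
IPNDMScheduler is the odd one out: its betas are built inside `set_timesteps` rather than in `__init__`, so the new argument is read back via `self.config`, which `@register_to_config` populates before the `__init__` body runs. A small usage sketch, assuming a diffusers build that includes this commit:

from diffusers import IPNDMScheduler

# The decorator recorded `trained_betas` on `scheduler.config` before the
# __init__ body ran, which is why set_timesteps could already read it there.
scheduler = IPNDMScheduler(trained_betas=[0.0, 0.1])
print(scheduler.config.trained_betas)  # [0.0, 0.1]
print(scheduler.betas)                 # tensor([0.0000, 0.1000])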

src/diffusers/schedulers/scheduling_lms_discrete.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import warnings
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -77,10 +77,10 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
     ):
         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":

src/diffusers/schedulers/scheduling_pndm.py

@@ -15,7 +15,7 @@
 # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
 
 import math
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -99,13 +99,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         skip_prk_steps: bool = False,
         set_alpha_to_one: bool = False,
         steps_offset: int = 0,
     ):
         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":

tests/test_scheduler.py

@@ -584,6 +584,20 @@ class SchedulerCommonTest(unittest.TestCase):
             " deprecated argument from `_deprecated_kwargs = [<deprecated_argument>]`"
         )
 
+    def test_trained_betas(self):
+        for scheduler_class in self.scheduler_classes:
+            if scheduler_class == VQDiffusionScheduler:
+                continue
+
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config, trained_betas=np.array([0.0, 0.1]))
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                scheduler.save_pretrained(tmpdirname)
+                new_scheduler = scheduler_class.from_pretrained(tmpdirname)
+
+            assert scheduler.betas.tolist() == new_scheduler.betas.tolist()
+
 
 class DDPMSchedulerTest(SchedulerCommonTest):
     scheduler_classes = (DDPMScheduler,)
@@ -1423,7 +1437,6 @@ class LMSDiscreteSchedulerTest(SchedulerCommonTest):
             "beta_start": 0.0001,
             "beta_end": 0.02,
             "beta_schedule": "linear",
-            "trained_betas": None,
         }
 
         config.update(**kwargs)
@@ -1505,7 +1518,6 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest):
             "beta_start": 0.0001,
             "beta_end": 0.02,
             "beta_schedule": "linear",
-            "trained_betas": None,
         }
 
         config.update(**kwargs)
@@ -1596,7 +1608,6 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest):
             "beta_start": 0.0001,
             "beta_end": 0.02,
             "beta_schedule": "linear",
-            "trained_betas": None,
         }
 
         config.update(**kwargs)
@@ -1905,7 +1916,6 @@ class HeunDiscreteSchedulerTest(SchedulerCommonTest):
             "beta_start": 0.0001,
             "beta_end": 0.02,
             "beta_schedule": "linear",
-            "trained_betas": None,
         }
 
         config.update(**kwargs)
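
The new `test_trained_betas` ties the pieces together: save with an ndarray, reload, and compare. The `"trained_betas": None` entries are dropped from the per-scheduler default configs, apparently because the test passes `trained_betas` explicitly alongside `**scheduler_config`, which would otherwise raise a duplicate-keyword error. The same round trip as a standalone sketch (assuming a diffusers build with this commit; DDPMScheduler chosen for illustration):

import tempfile

import numpy as np
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(trained_betas=np.array([0.0, 0.1]))

with tempfile.TemporaryDirectory() as tmpdirname:
    # The ndarray is converted to a list by `to_json_saveable` on save...
    scheduler.save_pretrained(tmpdirname)
    # ...and comes back as a list, which `torch.tensor` now accepts on load.
    new_scheduler = DDPMScheduler.from_pretrained(tmpdirname)

assert scheduler.betas.tolist() == new_scheduler.betas.tolist()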