Allow saving trained betas (#1468)
commit 110ffe2589
parent 0b7225e918
committed by GitHub
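Summary: every scheduler's `trained_betas` argument is widened from `Optional[np.ndarray]` to `Optional[Union[np.ndarray, List[float]]]` and materialized with `torch.tensor(..., dtype=torch.float32)` instead of `torch.from_numpy`; `ConfigMixin.to_json_string` learns to convert NumPy arrays to plain lists before JSON serialization; `IPNDMScheduler` gains a `trained_betas` option; and a common test checks that trained betas survive a `save_pretrained`/`from_pretrained` round trip.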
@@ -24,6 +24,8 @@ import re
 from collections import OrderedDict
 from typing import Any, Dict, Tuple, Union

+import numpy as np
+
 from huggingface_hub import hf_hub_download
 from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
 from requests import HTTPError
@@ -502,6 +504,12 @@ class ConfigMixin:
         config_dict["_class_name"] = self.__class__.__name__
         config_dict["_diffusers_version"] = __version__

+        def to_json_saveable(value):
+            if isinstance(value, np.ndarray):
+                value = value.tolist()
+            return value
+
+        config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
         return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

     def to_json_file(self, json_file_path: Union[str, os.PathLike]):
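For context, the reason `to_json_saveable` is needed: `json.dumps` cannot serialize a NumPy array, so a trained-betas array would previously have made `to_json_string` fail. A minimal standalone sketch of the conversion (the dict contents here are illustrative):

    import json

    import numpy as np

    # Without the conversion, json.dumps raises:
    #   TypeError: Object of type ndarray is not JSON serializable
    config_dict = {"trained_betas": np.array([0.0, 0.1]), "beta_schedule": "linear"}

    def to_json_saveable(value):
        # Convert NumPy arrays to plain lists so json.dumps can handle them
        if isinstance(value, np.ndarray):
            value = value.tolist()
        return value

    config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
    print(json.dumps(config_dict, indent=2, sort_keys=True))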
@@ -17,7 +17,7 @@

 import math
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -123,7 +123,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         clip_sample: bool = True,
         set_alpha_to_one: bool = True,
         steps_offset: int = 0,
@@ -139,7 +139,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
             self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")

         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":
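The `torch.from_numpy` to `torch.tensor` swap is what makes the widened annotation work in practice: once a config has been saved, the betas come back from JSON as a plain Python list, which `from_numpy` rejects. `torch.tensor` also pins the dtype, so NumPy's float64 default does not leak into the scheduler. Illustration (the output comments show what these calls produce):

    import numpy as np
    import torch

    betas_list = [0.0, 0.1]             # what a reloaded JSON config contains
    betas_array = np.array(betas_list)  # float64 by default

    # torch.from_numpy only accepts ndarrays and keeps their dtype:
    torch.from_numpy(betas_array)   # tensor([0.0000, 0.1000], dtype=torch.float64)
    # torch.from_numpy(betas_list)  # TypeError: expected np.ndarray (got list)

    # torch.tensor accepts both and pins the dtype to float32:
    torch.tensor(betas_array, dtype=torch.float32)  # tensor([0.0000, 0.1000])
    torch.tensor(betas_list, dtype=torch.float32)   # tensor([0.0000, 0.1000])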
@@ -16,7 +16,7 @@

 import math
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -115,7 +115,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         variance_type: str = "fixed_small",
         clip_sample: bool = True,
         prediction_type: str = "epsilon",
@@ -130,7 +130,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")

         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":
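Taken together, the two DDPMScheduler changes mean an array and its list form now construct identical schedules. A usage sketch (assumes a diffusers install containing this commit; the linspace betas are just example values standing in for betas learned during training):

    import numpy as np
    from diffusers import DDPMScheduler

    # Example "trained" betas, same length as the default num_train_timesteps
    betas = np.linspace(1e-4, 2e-2, 1000)

    s1 = DDPMScheduler(trained_betas=betas)
    s2 = DDPMScheduler(trained_betas=betas.tolist())

    # Both branches go through torch.tensor(..., dtype=torch.float32)
    assert s1.betas.dtype == s2.betas.dtype
    assert s1.betas.tolist() == s2.betas.tolist()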
@@ -127,7 +127,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         solver_order: int = 2,
         prediction_type: str = "epsilon",
         thresholding: bool = False,
@@ -147,7 +147,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")

         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":
@@ -13,7 +13,7 @@
 # limitations under the License.

 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -77,10 +77,10 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
     ):
         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":
@@ -13,7 +13,7 @@
 # limitations under the License.

 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -78,11 +78,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         prediction_type: str = "epsilon",
     ):
         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -53,10 +53,10 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
         beta_start: float = 0.00085,  # sensible defaults
         beta_end: float = 0.012,
         beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
     ):
         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":
@@ -13,8 +13,9 @@
 # limitations under the License.

 import math
-from typing import Tuple, Union
+from typing import List, Optional, Tuple, Union

+import numpy as np
 import torch

 from ..configuration_utils import ConfigMixin, register_to_config
@@ -40,7 +41,9 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
     order = 1

     @register_to_config
-    def __init__(self, num_train_timesteps: int = 1000):
+    def __init__(
+        self, num_train_timesteps: int = 1000, trained_betas: Optional[Union[np.ndarray, List[float]]] = None
+    ):
         # set `betas`, `alphas`, `timesteps`
         self.set_timesteps(num_train_timesteps)
@@ -67,7 +70,11 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
         steps = torch.linspace(1, 0, num_inference_steps + 1)[:-1]
         steps = torch.cat([steps, torch.tensor([0.0])])

-        self.betas = torch.sin(steps * math.pi / 2) ** 2
+        if self.config.trained_betas is not None:
+            self.betas = torch.tensor(self.config.trained_betas, dtype=torch.float32)
+        else:
+            self.betas = torch.sin(steps * math.pi / 2) ** 2

         self.alphas = (1.0 - self.betas**2) ** 0.5

         timesteps = (torch.atan2(self.betas, self.alphas) / math.pi * 2)[:-1]
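IPNDMScheduler is the odd one out: it has no beta_start/beta_end schedule in `__init__`; its betas are computed inside `set_timesteps` from a sin² ramp, which is why the override reads `self.config.trained_betas` there instead of a local `__init__` argument. The two branches in isolation (standalone sketch, values illustrative):

    import math

    import torch

    num_inference_steps = 4
    trained_betas = None  # e.g. [0.1, 0.2, 0.3, 0.4, 0.5] to take the override branch

    steps = torch.linspace(1, 0, num_inference_steps + 1)[:-1]
    steps = torch.cat([steps, torch.tensor([0.0])])

    if trained_betas is not None:
        betas = torch.tensor(trained_betas, dtype=torch.float32)
    else:
        betas = torch.sin(steps * math.pi / 2) ** 2  # default sin^2 ramp

    # IPNDM keeps beta^2 + alpha^2 = 1 rather than beta + alpha = 1
    alphas = (1.0 - betas**2) ** 0.5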
@@ -13,7 +13,7 @@
 # limitations under the License.
 import warnings
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -77,10 +77,10 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
     ):
         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":
@@ -15,7 +15,7 @@
 # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim

 import math
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -99,13 +99,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         skip_prk_steps: bool = False,
         set_alpha_to_one: bool = False,
         steps_offset: int = 0,
     ):
         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":
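Every signature touched above sits behind `@register_to_config`, which is why changing the annotation and the tensor construction is the whole fix: the decorator records each `__init__` argument, `trained_betas` included, into the config dict that `to_json_string` later serializes. A deliberately simplified toy version of that mechanism (hypothetical sketch, not the actual diffusers implementation):

    import functools
    import inspect

    def register_to_config(init):
        # Toy sketch: capture all init arguments (defaults included) as the "config"
        @functools.wraps(init)
        def wrapper(self, *args, **kwargs):
            bound = inspect.signature(init).bind(self, *args, **kwargs)
            bound.apply_defaults()
            config = dict(bound.arguments)
            config.pop("self")
            self.config = config
            init(self, *args, **kwargs)
        return wrapper

    class ToyScheduler:
        @register_to_config
        def __init__(self, num_train_timesteps: int = 1000, trained_betas=None):
            pass

    s = ToyScheduler(trained_betas=[0.0, 0.1])
    print(s.config)  # {'num_train_timesteps': 1000, 'trained_betas': [0.0, 0.1]}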
@@ -584,6 +584,20 @@ class SchedulerCommonTest(unittest.TestCase):
                 " deprecated argument from `_deprecated_kwargs = [<deprecated_argument>]`"
             )

+    def test_trained_betas(self):
+        for scheduler_class in self.scheduler_classes:
+            if scheduler_class == VQDiffusionScheduler:
+                continue
+
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config, trained_betas=np.array([0.0, 0.1]))
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                scheduler.save_pretrained(tmpdirname)
+                new_scheduler = scheduler_class.from_pretrained(tmpdirname)
+
+            assert scheduler.betas.tolist() == new_scheduler.betas.tolist()
+

 class DDPMSchedulerTest(SchedulerCommonTest):
     scheduler_classes = (DDPMScheduler,)
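The new common test pins down the end-to-end behavior. As a user would exercise it (sketch assuming a diffusers install with this commit; the betas mirror the test's values, and `save_pretrained` now succeeds because the ndarray is converted to a list before JSON serialization):

    import tempfile

    import numpy as np
    from diffusers import DDIMScheduler

    scheduler = DDIMScheduler(trained_betas=np.array([0.0, 0.1]))

    with tempfile.TemporaryDirectory() as tmpdirname:
        scheduler.save_pretrained(tmpdirname)  # writes scheduler_config.json
        reloaded = DDIMScheduler.from_pretrained(tmpdirname)

    # trained betas survive the save/load round trip
    assert scheduler.betas.tolist() == reloaded.betas.tolist()

The four hunks below drop the redundant `"trained_betas": None` entries from the test configs; the test above supplies `trained_betas` explicitly, and passing it both through the config and as a keyword would collide.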
@@ -1423,7 +1437,6 @@ class LMSDiscreteSchedulerTest(SchedulerCommonTest):
             "beta_start": 0.0001,
             "beta_end": 0.02,
             "beta_schedule": "linear",
-            "trained_betas": None,
         }

         config.update(**kwargs)
@@ -1505,7 +1518,6 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest):
             "beta_start": 0.0001,
             "beta_end": 0.02,
             "beta_schedule": "linear",
-            "trained_betas": None,
         }

         config.update(**kwargs)
@@ -1596,7 +1608,6 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest):
             "beta_start": 0.0001,
             "beta_end": 0.02,
             "beta_schedule": "linear",
-            "trained_betas": None,
         }

         config.update(**kwargs)
@@ -1905,7 +1916,6 @@ class HeunDiscreteSchedulerTest(SchedulerCommonTest):
             "beta_start": 0.0001,
             "beta_end": 0.02,
             "beta_schedule": "linear",
-            "trained_betas": None,
         }

         config.update(**kwargs)