1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Pytorch] add dep. warning for pytorch schedulers (#651)

* add dep. warning for schedulers

* fix format
This commit is contained in:
Kashif Rasul
2022-09-27 18:39:34 +02:00
committed by GitHub
parent 3304538229
commit 85494e8818
8 changed files with 69 additions and 1 deletions

View File

@@ -120,7 +120,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
clip_sample: bool = True,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
**kwargs,
):
if "tensor_format" in kwargs:
warnings.warn(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument.",
DeprecationWarning,
)
if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas)
if beta_schedule == "linear":

View File

@@ -15,6 +15,7 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union
@@ -112,7 +113,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
trained_betas: Optional[np.ndarray] = None,
variance_type: str = "fixed_small",
clip_sample: bool = True,
**kwargs,
):
if "tensor_format" in kwargs:
warnings.warn(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument.",
DeprecationWarning,
)
if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas)
elif beta_schedule == "linear":

View File

@@ -13,6 +13,7 @@
# limitations under the License.
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union
@@ -86,7 +87,15 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
s_churn: float = 80,
s_min: float = 0.05,
s_max: float = 50,
**kwargs,
):
if "tensor_format" in kwargs:
warnings.warn(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument.",
DeprecationWarning,
)
# setable values
self.num_inference_steps: int = None
self.timesteps: np.ndarray = None

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union
@@ -74,7 +75,15 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[np.ndarray] = None,
**kwargs,
):
if "tensor_format" in kwargs:
warnings.warn(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument.",
DeprecationWarning,
)
if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas)
if beta_schedule == "linear":

View File

@@ -100,7 +100,15 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
skip_prk_steps: bool = False,
set_alpha_to_one: bool = False,
steps_offset: int = 0,
**kwargs,
):
if "tensor_format" in kwargs:
warnings.warn(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument.",
DeprecationWarning,
)
if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas)
if beta_schedule == "linear":

View File

@@ -76,7 +76,15 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
sigma_max: float = 1348.0,
sampling_eps: float = 1e-5,
correct_steps: int = 1,
**kwargs,
):
if "tensor_format" in kwargs:
warnings.warn(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument.",
DeprecationWarning,
)
# setable values
self.timesteps = None

View File

@@ -17,6 +17,7 @@
# TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean-up a bit
import math
import warnings
import torch
@@ -40,7 +41,13 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
"""
@register_to_config
def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3):
def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3, **kwargs):
if "tensor_format" in kwargs:
warnings.warn(
"`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this argument.",
DeprecationWarning,
)
self.sigmas = None
self.discrete_sigmas = None
self.timesteps = None

View File

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from dataclasses import dataclass
import torch
@@ -41,3 +42,12 @@ class SchedulerMixin:
"""
config_name = SCHEDULER_CONFIG_NAME
def set_format(self, tensor_format="pt"):
warnings.warn(
"The method `set_format` is deprecated and will be removed in version `0.5.0`."
"If you're running your code in PyTorch, you can safely remove this function as the schedulers"
"are always in Pytorch",
DeprecationWarning,
)
return self