Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
[Pytorch] add dep. warning for pytorch schedulers (#651)
* add dep. warning for schedulers
* fix format
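Every scheduler hunk below applies the same pattern: accept `**kwargs` in `__init__` and, when the removed `tensor_format` argument is passed, emit a `DeprecationWarning` instead of letting Python fail with a `TypeError`, so old call sites keep working until `0.5.0`. A minimal standalone sketch of the pattern (`ExampleScheduler` and its `num_train_timesteps` parameter are hypothetical stand-ins, not diffusers classes):

import warnings


class ExampleScheduler:
    """Hypothetical stand-in illustrating the deprecation pattern from this commit."""

    def __init__(self, num_train_timesteps=1000, **kwargs):
        # Swallow the removed `tensor_format` kwarg and warn instead of
        # raising a TypeError for an unexpected keyword argument.
        if "tensor_format" in kwargs:
            warnings.warn(
                "`tensor_format` is deprecated as an argument and will be removed "
                "in version `0.5.0`.",
                DeprecationWarning,
            )
        self.num_train_timesteps = num_train_timesteps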
@@ -120,7 +120,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         clip_sample: bool = True,
         set_alpha_to_one: bool = True,
         steps_offset: int = 0,
+        **kwargs,
     ):
+        if "tensor_format" in kwargs:
+            warnings.warn(
+                "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
+                "If you're running your code in PyTorch, you can safely remove this argument.",
+                DeprecationWarning,
+            )
+
         if trained_betas is not None:
             self.betas = torch.from_numpy(trained_betas)
         if beta_schedule == "linear":
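Against the patched `DDIMScheduler`, passing the old argument now only warns; a usage sketch (assumes a diffusers install that includes this commit):

import warnings

from diffusers import DDIMScheduler

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # `tensor_format` is accepted for backward compatibility but otherwise ignored.
    scheduler = DDIMScheduler(tensor_format="pt")

assert any(issubclass(w.category, DeprecationWarning) for w in caught)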
@@ -15,6 +15,7 @@
 # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
 
 import math
+import warnings
 from dataclasses import dataclass
 from typing import Optional, Tuple, Union
 
@@ -112,7 +113,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         trained_betas: Optional[np.ndarray] = None,
         variance_type: str = "fixed_small",
         clip_sample: bool = True,
+        **kwargs,
     ):
+        if "tensor_format" in kwargs:
+            warnings.warn(
+                "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
+                "If you're running your code in PyTorch, you can safely remove this argument.",
+                DeprecationWarning,
+            )
+
         if trained_betas is not None:
             self.betas = torch.from_numpy(trained_betas)
         elif beta_schedule == "linear":
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 
+import warnings
 from dataclasses import dataclass
 from typing import Optional, Tuple, Union
 
@@ -86,7 +87,15 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
         s_churn: float = 80,
         s_min: float = 0.05,
         s_max: float = 50,
+        **kwargs,
     ):
+        if "tensor_format" in kwargs:
+            warnings.warn(
+                "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
+                "If you're running your code in PyTorch, you can safely remove this argument.",
+                DeprecationWarning,
+            )
+
         # setable values
         self.num_inference_steps: int = None
         self.timesteps: np.ndarray = None
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import warnings
 from dataclasses import dataclass
 from typing import Optional, Tuple, Union
 
@@ -74,7 +75,15 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
         trained_betas: Optional[np.ndarray] = None,
+        **kwargs,
     ):
+        if "tensor_format" in kwargs:
+            warnings.warn(
+                "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
+                "If you're running your code in PyTorch, you can safely remove this argument.",
+                DeprecationWarning,
+            )
+
         if trained_betas is not None:
             self.betas = torch.from_numpy(trained_betas)
         if beta_schedule == "linear":
@@ -100,7 +100,15 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         skip_prk_steps: bool = False,
         set_alpha_to_one: bool = False,
         steps_offset: int = 0,
+        **kwargs,
     ):
+        if "tensor_format" in kwargs:
+            warnings.warn(
+                "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
+                "If you're running your code in PyTorch, you can safely remove this argument.",
+                DeprecationWarning,
+            )
+
         if trained_betas is not None:
             self.betas = torch.from_numpy(trained_betas)
         if beta_schedule == "linear":
@@ -76,7 +76,15 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
         sigma_max: float = 1348.0,
         sampling_eps: float = 1e-5,
         correct_steps: int = 1,
+        **kwargs,
     ):
+        if "tensor_format" in kwargs:
+            warnings.warn(
+                "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
+                "If you're running your code in PyTorch, you can safely remove this argument.",
+                DeprecationWarning,
+            )
+
         # setable values
         self.timesteps = None
 
@@ -17,6 +17,7 @@
 # TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean-up a bit
 
 import math
+import warnings
 
 import torch
 
@@ -40,7 +41,13 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
     """
 
     @register_to_config
-    def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3):
+    def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3, **kwargs):
+        if "tensor_format" in kwargs:
+            warnings.warn(
+                "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`."
+                "If you're running your code in PyTorch, you can safely remove this argument.",
+                DeprecationWarning,
+            )
         self.sigmas = None
         self.discrete_sigmas = None
         self.timesteps = None
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import warnings
 from dataclasses import dataclass
 
 import torch
@@ -41,3 +42,12 @@ class SchedulerMixin:
     """
 
     config_name = SCHEDULER_CONFIG_NAME
+
+    def set_format(self, tensor_format="pt"):
+        warnings.warn(
+            "The method `set_format` is deprecated and will be removed in version `0.5.0`."
+            "If you're running your code in PyTorch, you can safely remove this function as the schedulers"
+            "are always in Pytorch",
+            DeprecationWarning,
+        )
+        return self
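`set_format` survives on `SchedulerMixin` as a warn-and-return no-op, so legacy chained calls keep working; a sketch of the preserved call pattern (again assuming a diffusers install that includes this commit):

import warnings

from diffusers import PNDMScheduler

# `set_format` now just warns and returns self, so legacy chaining still runs.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", DeprecationWarning)
    scheduler = PNDMScheduler().set_format(tensor_format="pt")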