diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index ccff870609..0d9e285e05 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -120,7 +120,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): clip_sample: bool = True, set_alpha_to_one: bool = True, steps_offset: int = 0, + **kwargs, ): + if "tensor_format" in kwargs: + warnings.warn( + "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`. " + "If you're running your code in PyTorch, you can safely remove this argument.", + DeprecationWarning, + ) + if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) if beta_schedule == "linear": diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 7f8988fdfd..cc17cee4c8 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -15,6 +15,7 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim import math +import warnings from dataclasses import dataclass from typing import Optional, Tuple, Union @@ -112,7 +113,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): trained_betas: Optional[np.ndarray] = None, variance_type: str = "fixed_small", clip_sample: bool = True, + **kwargs, ): + if "tensor_format" in kwargs: + warnings.warn( + "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`. "
+ "If you're running your code in PyTorch, you can safely remove this argument.", + DeprecationWarning, + ) + if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) elif beta_schedule == "linear": diff --git a/src/diffusers/schedulers/scheduling_karras_ve.py b/src/diffusers/schedulers/scheduling_karras_ve.py index 5826858fae..e6e5300e73 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve.py +++ b/src/diffusers/schedulers/scheduling_karras_ve.py @@ -13,6 +13,7 @@ # limitations under the License. +import warnings from dataclasses import dataclass from typing import Optional, Tuple, Union @@ -86,7 +87,15 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin): s_churn: float = 80, s_min: float = 0.05, s_max: float = 50, + **kwargs, ): + if "tensor_format" in kwargs: + warnings.warn( + "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`. " + "If you're running your code in PyTorch, you can safely remove this argument.", + DeprecationWarning, + ) + # setable values self.num_inference_steps: int = None self.timesteps: np.ndarray = None diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 6167af5ad4..6d8db7682d 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from dataclasses import dataclass from typing import Optional, Tuple, Union @@ -74,7 +75,15 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[np.ndarray] = None, + **kwargs, ): + if "tensor_format" in kwargs: + warnings.warn( + "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`. "
+ "If you're running your code in PyTorch, you can safely remove this argument.", + DeprecationWarning, + ) + if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) if beta_schedule == "linear": diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index ade223e2fb..d9e430c4a6 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -100,7 +100,15 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): skip_prk_steps: bool = False, set_alpha_to_one: bool = False, steps_offset: int = 0, + **kwargs, ): + if "tensor_format" in kwargs: + warnings.warn( + "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`. " + "If you're running your code in PyTorch, you can safely remove this argument.", + DeprecationWarning, + ) + if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) if beta_schedule == "linear": diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py index 7b06ae16c5..a549654c3b 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -76,7 +76,15 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): sigma_max: float = 1348.0, sampling_eps: float = 1e-5, correct_steps: int = 1, + **kwargs, ): + if "tensor_format" in kwargs: + warnings.warn( + "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`. "
+ "If you're running your code in PyTorch, you can safely remove this argument.", + DeprecationWarning, + ) + # setable values self.timesteps = None diff --git a/src/diffusers/schedulers/scheduling_sde_vp.py b/src/diffusers/schedulers/scheduling_sde_vp.py index 2f9821579c..daea743873 100644 --- a/src/diffusers/schedulers/scheduling_sde_vp.py +++ b/src/diffusers/schedulers/scheduling_sde_vp.py @@ -17,6 +17,7 @@ # TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean-up a bit import math +import warnings import torch @@ -40,7 +41,13 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin): """ @register_to_config - def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3): + def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3, **kwargs): + if "tensor_format" in kwargs: + warnings.warn( + "`tensor_format` is deprecated as an argument and will be removed in version `0.5.0`. " + "If you're running your code in PyTorch, you can safely remove this argument.", + DeprecationWarning, + ) self.sigmas = None self.discrete_sigmas = None self.timesteps = None diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index 29bf982f6a..1cc1d94414 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import warnings from dataclasses import dataclass import torch @@ -41,3 +42,12 @@ class SchedulerMixin: """ config_name = SCHEDULER_CONFIG_NAME + + def set_format(self, tensor_format="pt"): + warnings.warn( + "The method `set_format` is deprecated and will be removed in version `0.5.0`. "
+ "If you're running your code in PyTorch, you can safely remove this function as the schedulers " + "are always in PyTorch", + DeprecationWarning, + ) + return self