diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 2f4d2ab6dc..7e04aa0ac8 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -11,5 +11,5 @@ from .models.unet_ldm import UNetLDMModel from .models.unet_grad_tts import UNetGradTTSModel from .pipeline_utils import DiffusionPipeline from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion, PNDM, BDDM -from .schedulers import DDIMScheduler, DDPMScheduler, SchedulerMixin, PNDMScheduler +from .schedulers import DDIMScheduler, DDPMScheduler, SchedulerMixin, PNDMScheduler, GradTTSScheduler from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 5e9dcaf64e..9e1cd3edc8 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -20,4 +20,5 @@ from .classifier_free_guidance import ClassifierFreeGuidanceScheduler from .scheduling_ddim import DDIMScheduler from .scheduling_ddpm import DDPMScheduler from .scheduling_pndm import PNDMScheduler +from .scheduling_grad_tts import GradTTSScheduler from .scheduling_utils import SchedulerMixin diff --git a/src/diffusers/schedulers/scheduling_grad_tts.py b/src/diffusers/schedulers/scheduling_grad_tts.py new file mode 100644 index 0000000000..11a557a3e1 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_grad_tts.py @@ -0,0 +1,52 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import numpy as np + +from ..configuration_utils import ConfigMixin +from .scheduling_utils import SchedulerMixin + + +class GradTTSScheduler(SchedulerMixin, ConfigMixin): + def __init__( + self, + timesteps=1000, + beta_start=0.0001, + beta_end=0.02, + tensor_format="np", + ): + super().__init__() + self.register( + timesteps=timesteps, + beta_start=beta_start, + beta_end=beta_end, + ) + self.timesteps = int(timesteps) + + self.set_format(tensor_format=tensor_format) + + def sample_noise(self, timestep): + noise = self.beta_start + (self.beta_end - self.beta_start) * timestep + return noise + + def step(self, xt, residual, mu, h, timestep): + noise_t = self.sample_noise(timestep) + dxt = 0.5 * (mu - xt - residual) + dxt = dxt * noise_t * h + xt = xt - dxt + return xt + + def __len__(self): + return self.timesteps