diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 576f0646f9..ceff63c4b9 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -16,7 +16,30 @@ import math import numpy as np from ..configuration_utils import ConfigMixin -from .scheduling_utils import SchedulerMixin, betas_for_alpha_bar, linear_beta_schedule +from .scheduling_utils import SchedulerMixin + + +def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + def alpha_bar(time_step): + return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas, dtype=np.float32) class DDIMScheduler(SchedulerMixin, ConfigMixin): @@ -43,13 +66,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): ) if beta_schedule == "linear": - self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end) + self.betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float32) elif beta_schedule == "squaredcos_cap_v2": # GLIDE cosine schedule - self.betas = betas_for_alpha_bar( - timesteps, - lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, - ) + self.betas = betas_for_alpha_bar(timesteps) else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") @@ -59,28 +79,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): self.set_format(tensor_format=tensor_format) - # alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0) - # TODO(PVP) - check how much of these is actually necessary! - # LDM only uses "fixed_small"; glide seems to use a weird mix of the two, ... - # https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/gaussian_diffusion.py#L246 - # variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) - # if variance_type == "fixed_small": - # log_variance = torch.log(variance.clamp(min=1e-20)) - # elif variance_type == "fixed_large": - # log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0)) - # - # - # self.register_buffer("log_variance", log_variance.to(torch.float32)) - - # def rescale_betas(self, num_timesteps): - # # GLIDE scaling - # if self.beta_schedule == "linear": - # scale = self.timesteps / num_timesteps - # self.betas = linear_beta_schedule( - # num_timesteps, beta_start=self.beta_start * scale, beta_end=self.beta_end * scale - # ) - # self.alphas = 1.0 - self.betas - # self.alphas_cumprod = np.cumprod(self.alphas, axis=0) def get_timestep_values(self): return self.config.timestep_values diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 012b5a5bee..172f5647e9 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -16,7 +16,30 @@ import math import numpy as np from ..configuration_utils import ConfigMixin -from .scheduling_utils import SchedulerMixin, betas_for_alpha_bar, linear_beta_schedule +from .scheduling_utils import SchedulerMixin + + +def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + def alpha_bar(time_step): + return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas, dtype=np.float32) class DDPMScheduler(SchedulerMixin, ConfigMixin): @@ -47,13 +70,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): if trained_betas is not None: self.betas = np.asarray(trained_betas) elif beta_schedule == "linear": - self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end) + self.betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float32) elif beta_schedule == "squaredcos_cap_v2": # GLIDE cosine schedule - self.betas = betas_for_alpha_bar( - timesteps, - lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, - ) + self.betas = betas_for_alpha_bar(timesteps) else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 686c31140c..67a7105b57 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -11,12 +11,34 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np import math -import numpy as np - from ..configuration_utils import ConfigMixin -from .scheduling_utils import SchedulerMixin, betas_for_alpha_bar, linear_beta_schedule +from .scheduling_utils import SchedulerMixin + + +def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + def alpha_bar(time_step): + return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas, dtype=np.float32) class PNDMScheduler(SchedulerMixin, ConfigMixin): @@ -37,13 +59,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): ) if beta_schedule == "linear": - self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end) + self.betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float32) elif beta_schedule == "squaredcos_cap_v2": # GLIDE cosine schedule - self.betas = betas_for_alpha_bar( - timesteps, - lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, - ) + self.betas = betas_for_alpha_bar(timesteps) else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index 6cbc0212e2..a6f317852d 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -18,30 +18,6 @@ import torch SCHEDULER_CONFIG_NAME = "scheduler_config.json" -def linear_beta_schedule(timesteps, beta_start, beta_end): - return np.linspace(beta_start, beta_end, timesteps, dtype=np.float32) - - -def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): - """ - Create a beta schedule that discretizes the given alpha_t_bar function, - which defines the cumulative product of (1-beta) over time from t = [0,1]. - - :param num_diffusion_timesteps: the number of betas to produce. - :param alpha_bar: a lambda that takes an argument t from 0 to 1 and - produces the cumulative product of (1-beta) up to that - part of the diffusion process. - :param max_beta: the maximum beta to use; use values lower than 1 to - prevent singularities. - """ - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return np.array(betas, dtype=np.float32) - - class SchedulerMixin: config_name = SCHEDULER_CONFIG_NAME