mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
trained_betas ignored in some schedulers (#635)
* correcting the beta value assignment * updating DDIM and LMSDiscreteFlax schedulers * bringing back the changes that were lost as part of main branch merge
This commit is contained in:
@@ -131,7 +131,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.from_numpy(trained_betas)
|
||||
if beta_schedule == "linear":
|
||||
elif beta_schedule == "linear":
|
||||
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
||||
elif beta_schedule == "scaled_linear":
|
||||
# this schedule is very specific to the latent diffusion model.
|
||||
|
||||
@@ -86,7 +86,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.from_numpy(trained_betas)
|
||||
if beta_schedule == "linear":
|
||||
elif beta_schedule == "linear":
|
||||
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
||||
elif beta_schedule == "scaled_linear":
|
||||
# this schedule is very specific to the latent diffusion model.
|
||||
|
||||
@@ -74,7 +74,7 @@ class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
):
|
||||
if trained_betas is not None:
|
||||
self.betas = jnp.asarray(trained_betas)
|
||||
if beta_schedule == "linear":
|
||||
elif beta_schedule == "linear":
|
||||
self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
|
||||
elif beta_schedule == "scaled_linear":
|
||||
# this schedule is very specific to the latent diffusion model.
|
||||
|
||||
@@ -111,7 +111,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.from_numpy(trained_betas)
|
||||
if beta_schedule == "linear":
|
||||
elif beta_schedule == "linear":
|
||||
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
||||
elif beta_schedule == "scaled_linear":
|
||||
# this schedule is very specific to the latent diffusion model.
|
||||
|
||||
@@ -132,7 +132,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
):
|
||||
if trained_betas is not None:
|
||||
self.betas = jnp.asarray(trained_betas)
|
||||
if beta_schedule == "linear":
|
||||
elif beta_schedule == "linear":
|
||||
self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
|
||||
elif beta_schedule == "scaled_linear":
|
||||
# this schedule is very specific to the latent diffusion model.
|
||||
|
||||
Reference in New Issue
Block a user