mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
fix ddpm scheduler
This commit is contained in:
@@ -51,13 +51,7 @@ class DDPMPipeline(DiffusionPipeline):
|
||||
# 2. predict previous mean of image x_t-1
|
||||
pred_prev_image = self.scheduler.step(model_output, t, image)["prev_sample"]
|
||||
|
||||
# 3. optionally sample variance
|
||||
variance = 0
|
||||
if t > 0:
|
||||
noise = torch.randn(image.shape, generator=generator).to(image.device)
|
||||
variance = self.scheduler.get_variance(t).sqrt() * noise
|
||||
|
||||
# 4. set current image to prev_image: x_t -> x_t-1
|
||||
image = pred_prev_image + variance
|
||||
# 3. set current image to prev_image: x_t -> x_t-1
|
||||
image = pred_prev_image
|
||||
|
||||
return {"sample": image}
|
||||
|
||||
@@ -101,7 +101,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
)[::-1].copy()
|
||||
self.set_format(tensor_format=self.tensor_format)
|
||||
|
||||
def get_variance(self, t, variance_type=None):
|
||||
def _get_variance(self, t, variance_type=None):
|
||||
alpha_prod_t = self.alphas_cumprod[t]
|
||||
alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
|
||||
|
||||
@@ -133,6 +133,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
timestep: int,
|
||||
sample: Union[torch.FloatTensor, np.ndarray],
|
||||
predict_epsilon=True,
|
||||
generator=None,
|
||||
):
|
||||
t = timestep
|
||||
# 1. compute alphas, betas
|
||||
@@ -161,6 +162,14 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
|
||||
pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
|
||||
|
||||
# 6. Add noise
|
||||
variance = 0
|
||||
if t > 0:
|
||||
noise = torch.randn(model_output.shape, generator=generator).to(model_output.device)
|
||||
variance = self._get_variance(t).sqrt() * noise
|
||||
|
||||
pred_prev_sample = pred_prev_sample + variance
|
||||
|
||||
return {"prev_sample": pred_prev_sample}
|
||||
|
||||
def add_noise(self, original_samples, noise, timesteps):
|
||||
|
||||
Reference in New Issue
Block a user