1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

finish pndm sampler

This commit is contained in:
Patrick von Platen
2022-07-21 01:50:12 +00:00
parent fe98574622
commit 394243ce98
7 changed files with 77 additions and 70 deletions

View File

@@ -48,11 +48,8 @@ class DDPMPipeline(DiffusionPipeline):
# 1. predict noise model_output
model_output = self.unet(image, t)["sample"]
# 2. predict previous mean of image x_t-1
pred_prev_image = self.scheduler.step(model_output, t, image)["prev_sample"]
# 3. set current image to prev_image: x_t -> x_t-1
image = pred_prev_image
# 2. compute previous image: x_t -> x_t-1
image = self.scheduler.step(model_output, t, image)["prev_sample"]
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()

View File

@@ -44,15 +44,20 @@ class PNDMPipeline(DiffusionPipeline):
image = image.to(torch_device)
self.scheduler.set_timesteps(num_inference_steps)
for i, t in enumerate(tqdm(self.scheduler.prk_timesteps)):
for t in tqdm(self.scheduler.timesteps):
model_output = self.unet(image, t)["sample"]
image = self.scheduler.step_prk(model_output, i, image, num_inference_steps)["prev_sample"]
image = self.scheduler.step(model_output, t, image)["prev_sample"]
for i, t in enumerate(tqdm(self.scheduler.plms_timesteps)):
model_output = self.unet(image, t)["sample"]
image = self.scheduler.step_plms(model_output, i, image, num_inference_steps)["prev_sample"]
# for i, t in enumerate(tqdm(self.scheduler.prk_timesteps)):
# model_output = self.unet(image, t)["sample"]
#
# image = self.scheduler.step_prk(model_output, t, image, i=i)["prev_sample"]
#
# for i, t in enumerate(tqdm(self.scheduler.plms_timesteps)):
# model_output = self.unet(image, t)["sample"]
#
# image = self.scheduler.step_plms(model_output, t, image, i=i)["prev_sample"]
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()

View File

@@ -28,21 +28,15 @@ class ScoreSdeVePipeline(DiffusionPipeline):
for i, t in tqdm(enumerate(self.scheduler.timesteps)):
sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=device)
# correction step
for _ in range(self.scheduler.correct_steps):
model_output = self.model(sample, sigma_t)
if isinstance(model_output, dict):
model_output = model_output["sample"]
model_output = self.model(sample, sigma_t)["sample"]
sample = self.scheduler.step_correct(model_output, sample)["prev_sample"]
with torch.no_grad():
model_output = model(sample, sigma_t)
if isinstance(model_output, dict):
model_output = model_output["sample"]
# prediction step
model_output = model(sample, sigma_t)["sample"]
output = self.scheduler.step_pred(model_output, t, sample)
sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]
sample = sample.clamp(0, 1)

View File

@@ -106,8 +106,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
model_output: Union[torch.FloatTensor, np.ndarray],
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
eta,
use_clipped_model_output=False,
eta: float = 0.0,
use_clipped_model_output: bool = False,
generator=None,
):
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf

View File

@@ -56,7 +56,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
beta_end=0.02,
beta_schedule="linear",
trained_betas=None,
timestep_values=None,
variance_type="fixed_small",
clip_sample=True,
tensor_format="pt",

View File

@@ -15,7 +15,6 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math
import pdb
from typing import Union
import numpy as np
@@ -79,78 +78,91 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
# running values
self.cur_model_output = 0
self.counter = 0
self.cur_sample = None
self.ets = []
# setable values
self.num_inference_steps = None
self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
self.prk_timesteps = None
self.plms_timesteps = None
self.timesteps = None
self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format)
def set_timesteps(self, num_inference_steps):
    """
    Precompute the discrete timestep schedule used by the PNDM sampler.

    The raw training timesteps are subsampled to `num_inference_steps` evenly
    spaced values, then split into two phases:
      - `prk_timesteps`: the warm-up Runge-Kutta timesteps (each RK step needs
        4 model evaluations, hence the `.repeat(2)` duplication and half-step
        offsets), built from the last `pndm_order` subsampled timesteps.
      - `plms_timesteps`: the remaining timesteps, traversed in reverse, used
        by the linear multi-step phase.
    `self.timesteps` is the concatenation of both phases in sampling order,
    and the step counter is reset so a fresh sampling loop can begin.

    Args:
        num_inference_steps: number of denoising steps for this sampling run.
    """
    self.num_inference_steps = num_inference_steps
    # Evenly subsample the training timesteps (ascending order).
    self._timesteps = list(
        range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
    )

    # Build the RK warm-up schedule: duplicate the last `pndm_order` timesteps
    # and interleave half-step offsets for the intermediate RK evaluations.
    prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile(
        np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
    )
    self.prk_timesteps = list(reversed(prk_timesteps[:-1].repeat(2)[1:-1]))
    self.plms_timesteps = list(reversed(self._timesteps[:-3]))

    # Full sampling schedule: RK warm-up first, then the PLMS phase.
    self.timesteps = self.prk_timesteps + self.plms_timesteps

    self.counter = 0
    self.set_format(tensor_format=self.tensor_format)
def step(
    self,
    model_output: Union[torch.FloatTensor, np.ndarray],
    timestep: int,
    sample: Union[torch.FloatTensor, np.ndarray],
):
    """
    Predict the sample at the previous timestep, dispatching to the correct
    sub-stepper for the current phase of the schedule.

    While `self.counter` is still within the Runge-Kutta warm-up portion of
    the schedule (the first `len(self.prk_timesteps)` calls), delegate to
    `step_prk`; afterwards delegate to `step_plms`.

    Args:
        model_output: direct output of the learned diffusion model.
        timestep: current discrete timestep in the diffusion chain.
        sample: current instance of the sample being diffused.

    Returns:
        dict with key "prev_sample": the sample at the previous timestep,
        as produced by the selected sub-stepper.
    """
    if self.counter < len(self.prk_timesteps):
        return self.step_prk(model_output=model_output, timestep=timestep, sample=sample)
    else:
        return self.step_plms(model_output=model_output, timestep=timestep, sample=sample)
def step_prk(
    self,
    model_output: Union[torch.FloatTensor, np.ndarray],
    timestep: int,
    sample: Union[torch.FloatTensor, np.ndarray],
):
    """
    Step function propagating the sample with the Runge-Kutta method. RK takes
    4 forward passes (4 calls to this method) to approximate the solution to
    the differential equation.

    Args:
        model_output: direct output of the learned diffusion model.
        timestep: current discrete timestep in the diffusion chain.
        sample: current instance of the sample being diffused.

    Returns:
        dict with key "prev_sample": the sample at the previous timestep.
    """
    # On even sub-steps advance by half an inference stride; on odd sub-steps
    # stay at the same timestep. Never step past the final RK timestep.
    diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2
    prev_timestep = max(timestep - diff_to_prev, self.prk_timesteps[-1])
    # All 4 sub-steps of one RK step share the timestep of the group's start.
    timestep = self.prk_timesteps[self.counter // 4 * 4]

    # Accumulate the RK4 combination (1/6, 1/3, 1/3, 1/6) across sub-steps.
    if self.counter % 4 == 0:
        self.cur_model_output += 1 / 6 * model_output
        self.ets.append(model_output)
        self.cur_sample = sample
    elif (self.counter - 1) % 4 == 0:
        self.cur_model_output += 1 / 3 * model_output
    elif (self.counter - 2) % 4 == 0:
        self.cur_model_output += 1 / 3 * model_output
    elif (self.counter - 3) % 4 == 0:
        model_output = self.cur_model_output + 1 / 6 * model_output
        self.cur_model_output = 0

    # cur_sample should not be `None`
    cur_sample = self.cur_sample if self.cur_sample is not None else sample

    prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
    self.counter += 1

    return {"prev_sample": prev_sample}
def step_plms(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
num_inference_steps,
):
"""
Step function propagating the sample with the linear multi-step method. This reuses
the stored model outputs of earlier timesteps, so it needs only one forward pass per step.
"""
t = timestep
if len(self.ets) < 3:
raise ValueError(
f"{self.__class__} can only be run AFTER scheduler has been run "
@@ -159,17 +171,17 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
"for more information."
)
timesteps = self.plms_timesteps
t_orig = timesteps[t]
t_orig_prev = timesteps[min(t + 1, len(timesteps) - 1)]
prev_timestep = max(timestep - self.config.num_train_timesteps // self.num_inference_steps, 0)
self.ets.append(model_output)
model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
return {"prev_sample": self.get_prev_sample(sample, t_orig, t_orig_prev, model_output)}
prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
self.counter += 1
def get_prev_sample(self, sample, t_orig, t_orig_prev, model_output):
return {"prev_sample": prev_sample}
def _get_prev_sample(self, sample, timestep, timestep_prev, model_output):
# See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
# this function computes x_(t-δ) using the formula of (9)
# Note that x_t needs to be added to both sides of the equation
@@ -182,8 +194,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
# sample -> x_t
# model_output -> e_θ(x_t, t)
# prev_sample -> x_(t-δ)
alpha_prod_t = self.alphas_cumprod[t_orig + 1]
alpha_prod_t_prev = self.alphas_cumprod[t_orig_prev + 1]
alpha_prod_t = self.alphas_cumprod[timestep + 1]
alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1]
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev

View File

@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pdb
import tempfile
import unittest
@@ -383,6 +382,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
def check_over_configs(self, time_step=0, **config):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
sample = self.dummy_sample
residual = 0.1 * sample
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
@@ -390,14 +390,14 @@ class PNDMSchedulerTest(SchedulerCommonTest):
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(kwargs["num_inference_steps"])
scheduler.set_timesteps(num_inference_steps)
# copy over dummy past residuals
scheduler.ets = dummy_past_residuals[:]
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname)
new_scheduler.set_timesteps(kwargs["num_inference_steps"])
new_scheduler.set_timesteps(num_inference_steps)
# copy over dummy past residuals
new_scheduler.ets = dummy_past_residuals[:]
@@ -416,7 +416,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
def check_over_forward(self, time_step=0, **forward_kwargs):
kwargs = dict(self.forward_default_kwargs)
kwargs.update(forward_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
sample = self.dummy_sample
residual = 0.1 * sample
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
@@ -424,7 +424,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(kwargs["num_inference_steps"])
scheduler.set_timesteps(num_inference_steps)
# copy over dummy past residuals
scheduler.ets = dummy_past_residuals[:]
@@ -434,7 +434,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
new_scheduler = scheduler_class.from_config(tmpdirname)
# copy over dummy past residuals
new_scheduler.ets = dummy_past_residuals[:]
new_scheduler.set_timesteps(kwargs["num_inference_steps"])
new_scheduler.set_timesteps(num_inference_steps)
output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
@@ -474,12 +474,12 @@ class PNDMSchedulerTest(SchedulerCommonTest):
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
output = scheduler.step_prk(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, num_inference_steps, **kwargs)["prev_sample"]
output = scheduler.step_prk(residual, 1, sample, **kwargs)["prev_sample"]
output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, **kwargs)["prev_sample"]
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
output = scheduler.step_plms(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, num_inference_steps, **kwargs)["prev_sample"]
output = scheduler.step_plms(residual, 1, sample, **kwargs)["prev_sample"]
output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, **kwargs)["prev_sample"]
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
@@ -503,14 +503,14 @@ class PNDMSchedulerTest(SchedulerCommonTest):
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
output_0 = scheduler.step_prk(residual, 0, sample, num_inference_steps, **kwargs)["prev_sample"]
output_1 = scheduler.step_prk(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
output_0 = scheduler.step_prk(residual, 0, sample, **kwargs)["prev_sample"]
output_1 = scheduler.step_prk(residual, 1, sample, **kwargs)["prev_sample"]
self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape)
output_0 = scheduler.step_plms(residual, 0, sample, num_inference_steps, **kwargs)["prev_sample"]
output_1 = scheduler.step_plms(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
output_0 = scheduler.step_plms(residual, 0, sample, **kwargs)["prev_sample"]
output_1 = scheduler.step_plms(residual, 1, sample, **kwargs)["prev_sample"]
self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape)
@@ -541,7 +541,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample, 50)["prev_sample"]
scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample)["prev_sample"]
def test_full_loop_no_noise(self):
scheduler_class = self.scheduler_classes[0]
@@ -555,11 +555,11 @@ class PNDMSchedulerTest(SchedulerCommonTest):
for i, t in enumerate(scheduler.prk_timesteps):
residual = model(sample, t)
sample = scheduler.step_prk(residual, i, sample, num_inference_steps)["prev_sample"]
sample = scheduler.step_prk(residual, i, sample)["prev_sample"]
for i, t in enumerate(scheduler.plms_timesteps):
residual = model(sample, t)
sample = scheduler.step_plms(residual, i, sample, num_inference_steps)["prev_sample"]
sample = scheduler.step_plms(residual, i, sample)["prev_sample"]
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
@@ -706,7 +706,7 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
model_output = model(sample, sigma_t)
output = scheduler.step_pred(model_output, t, sample, **kwargs)
sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]
sample, _ = output["prev_sample"], output["prev_sample_mean"]
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))