mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
finish pndm sampler
This commit is contained in:
@@ -48,11 +48,8 @@ class DDPMPipeline(DiffusionPipeline):
|
||||
# 1. predict noise model_output
|
||||
model_output = self.unet(image, t)["sample"]
|
||||
|
||||
# 2. predict previous mean of image x_t-1
|
||||
pred_prev_image = self.scheduler.step(model_output, t, image)["prev_sample"]
|
||||
|
||||
# 3. set current image to prev_image: x_t -> x_t-1
|
||||
image = pred_prev_image
|
||||
# 2. compute previous image: x_t -> x_t-1
|
||||
image = self.scheduler.step(model_output, t, image)["prev_sample"]
|
||||
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
||||
|
||||
@@ -44,15 +44,20 @@ class PNDMPipeline(DiffusionPipeline):
|
||||
image = image.to(torch_device)
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
for i, t in enumerate(tqdm(self.scheduler.prk_timesteps)):
|
||||
for t in tqdm(self.scheduler.timesteps):
|
||||
model_output = self.unet(image, t)["sample"]
|
||||
|
||||
image = self.scheduler.step_prk(model_output, i, image, num_inference_steps)["prev_sample"]
|
||||
image = self.scheduler.step(model_output, t, image)["prev_sample"]
|
||||
|
||||
for i, t in enumerate(tqdm(self.scheduler.plms_timesteps)):
|
||||
model_output = self.unet(image, t)["sample"]
|
||||
|
||||
image = self.scheduler.step_plms(model_output, i, image, num_inference_steps)["prev_sample"]
|
||||
# for i, t in enumerate(tqdm(self.scheduler.prk_timesteps)):
|
||||
# model_output = self.unet(image, t)["sample"]
|
||||
#
|
||||
# image = self.scheduler.step_prk(model_output, t, image, i=i)["prev_sample"]
|
||||
#
|
||||
# for i, t in enumerate(tqdm(self.scheduler.plms_timesteps)):
|
||||
# model_output = self.unet(image, t)["sample"]
|
||||
#
|
||||
# image = self.scheduler.step_plms(model_output, t, image, i=i)["prev_sample"]
|
||||
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
||||
|
||||
@@ -28,21 +28,15 @@ class ScoreSdeVePipeline(DiffusionPipeline):
|
||||
for i, t in tqdm(enumerate(self.scheduler.timesteps)):
|
||||
sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=device)
|
||||
|
||||
# correction step
|
||||
for _ in range(self.scheduler.correct_steps):
|
||||
model_output = self.model(sample, sigma_t)
|
||||
|
||||
if isinstance(model_output, dict):
|
||||
model_output = model_output["sample"]
|
||||
|
||||
model_output = self.model(sample, sigma_t)["sample"]
|
||||
sample = self.scheduler.step_correct(model_output, sample)["prev_sample"]
|
||||
|
||||
with torch.no_grad():
|
||||
model_output = model(sample, sigma_t)
|
||||
|
||||
if isinstance(model_output, dict):
|
||||
model_output = model_output["sample"]
|
||||
|
||||
# prediction step
|
||||
model_output = model(sample, sigma_t)["sample"]
|
||||
output = self.scheduler.step_pred(model_output, t, sample)
|
||||
|
||||
sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]
|
||||
|
||||
sample = sample.clamp(0, 1)
|
||||
|
||||
@@ -106,8 +106,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
model_output: Union[torch.FloatTensor, np.ndarray],
|
||||
timestep: int,
|
||||
sample: Union[torch.FloatTensor, np.ndarray],
|
||||
eta,
|
||||
use_clipped_model_output=False,
|
||||
eta: float = 0.0,
|
||||
use_clipped_model_output: bool = False,
|
||||
generator=None,
|
||||
):
|
||||
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
|
||||
|
||||
@@ -56,7 +56,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
beta_end=0.02,
|
||||
beta_schedule="linear",
|
||||
trained_betas=None,
|
||||
timestep_values=None,
|
||||
variance_type="fixed_small",
|
||||
clip_sample=True,
|
||||
tensor_format="pt",
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
|
||||
|
||||
import math
|
||||
import pdb
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
@@ -79,78 +78,91 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# running values
|
||||
self.cur_model_output = 0
|
||||
self.counter = 0
|
||||
self.cur_sample = None
|
||||
self.ets = []
|
||||
|
||||
# settable values
|
||||
self.num_inference_steps = None
|
||||
self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
|
||||
self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
|
||||
self.prk_timesteps = None
|
||||
self.plms_timesteps = None
|
||||
self.timesteps = None
|
||||
|
||||
self.tensor_format = tensor_format
|
||||
self.set_format(tensor_format=tensor_format)
|
||||
|
||||
def set_timesteps(self, num_inference_steps):
|
||||
self.num_inference_steps = num_inference_steps
|
||||
self.timesteps = list(
|
||||
self._timesteps = list(
|
||||
range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
|
||||
)
|
||||
|
||||
prk_time_steps = np.array(self.timesteps[-self.pndm_order :]).repeat(2) + np.tile(
|
||||
prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile(
|
||||
np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
|
||||
)
|
||||
self.prk_timesteps = list(reversed(prk_time_steps[:-1].repeat(2)[1:-1]))
|
||||
self.plms_timesteps = list(reversed(self.timesteps[:-3]))
|
||||
self.prk_timesteps = list(reversed(prk_timesteps[:-1].repeat(2)[1:-1]))
|
||||
self.plms_timesteps = list(reversed(self._timesteps[:-3]))
|
||||
self.timesteps = self.prk_timesteps + self.plms_timesteps
|
||||
|
||||
self.counter = 0
|
||||
self.set_format(tensor_format=self.tensor_format)
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: Union[torch.FloatTensor, np.ndarray],
|
||||
timestep: int,
|
||||
sample: Union[torch.FloatTensor, np.ndarray],
|
||||
):
"""
Advance the sample by one scheduler step, choosing the update rule from the internal step counter.

While `self.counter` is smaller than `len(self.prk_timesteps)` the call is forwarded to `step_prk`; every later
call is forwarded to `step_plms`. All three arguments are passed through unchanged as keyword arguments, and the
value returned by the selected method is returned as-is.

`set_timesteps` must be called first: it fills `self.prk_timesteps` (which is `None` until then) and resets
`self.counter` to 0.
"""
|
||||
if self.counter < len(self.prk_timesteps):
|
||||
return self.step_prk(model_output=model_output, timestep=timestep, sample=sample)
|
||||
else:
|
||||
return self.step_plms(model_output=model_output, timestep=timestep, sample=sample)
|
||||
|
||||
def step_prk(
|
||||
self,
|
||||
model_output: Union[torch.FloatTensor, np.ndarray],
|
||||
timestep: int,
|
||||
sample: Union[torch.FloatTensor, np.ndarray],
|
||||
num_inference_steps,
|
||||
):
|
||||
"""
|
||||
Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
|
||||
solution to the differential equation.
|
||||
"""
|
||||
t = timestep
|
||||
prk_time_steps = self.prk_timesteps
|
||||
diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2
|
||||
prev_timestep = max(timestep - diff_to_prev, self.prk_timesteps[-1])
|
||||
timestep = self.prk_timesteps[self.counter // 4 * 4]
|
||||
|
||||
t_orig = prk_time_steps[t // 4 * 4]
|
||||
t_orig_prev = prk_time_steps[min(t + 1, len(prk_time_steps) - 1)]
|
||||
|
||||
if t % 4 == 0:
|
||||
if self.counter % 4 == 0:
|
||||
self.cur_model_output += 1 / 6 * model_output
|
||||
self.ets.append(model_output)
|
||||
self.cur_sample = sample
|
||||
elif (t - 1) % 4 == 0:
|
||||
elif (self.counter - 1) % 4 == 0:
|
||||
self.cur_model_output += 1 / 3 * model_output
|
||||
elif (t - 2) % 4 == 0:
|
||||
elif (self.counter - 2) % 4 == 0:
|
||||
self.cur_model_output += 1 / 3 * model_output
|
||||
elif (t - 3) % 4 == 0:
|
||||
elif (self.counter - 3) % 4 == 0:
|
||||
model_output = self.cur_model_output + 1 / 6 * model_output
|
||||
self.cur_model_output = 0
|
||||
|
||||
# cur_sample should not be `None`
|
||||
cur_sample = self.cur_sample if self.cur_sample is not None else sample
|
||||
|
||||
return {"prev_sample": self.get_prev_sample(cur_sample, t_orig, t_orig_prev, model_output)}
|
||||
prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
|
||||
self.counter += 1
|
||||
|
||||
return {"prev_sample": prev_sample}
|
||||
|
||||
def step_plms(
|
||||
self,
|
||||
model_output: Union[torch.FloatTensor, np.ndarray],
|
||||
timestep: int,
|
||||
sample: Union[torch.FloatTensor, np.ndarray],
|
||||
num_inference_steps,
|
||||
):
|
||||
"""
|
||||
Step function propagating the sample with the linear multi-step method. This needs only one forward pass per step
|
||||
and reuses several previous model outputs to approximate the solution.
|
||||
"""
|
||||
t = timestep
|
||||
if len(self.ets) < 3:
|
||||
raise ValueError(
|
||||
f"{self.__class__} can only be run AFTER scheduler has been run "
|
||||
@@ -159,17 +171,17 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
"for more information."
|
||||
)
|
||||
|
||||
timesteps = self.plms_timesteps
|
||||
|
||||
t_orig = timesteps[t]
|
||||
t_orig_prev = timesteps[min(t + 1, len(timesteps) - 1)]
|
||||
prev_timestep = max(timestep - self.config.num_train_timesteps // self.num_inference_steps, 0)
|
||||
self.ets.append(model_output)
|
||||
|
||||
model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
|
||||
|
||||
return {"prev_sample": self.get_prev_sample(sample, t_orig, t_orig_prev, model_output)}
|
||||
prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
|
||||
self.counter += 1
|
||||
|
||||
def get_prev_sample(self, sample, t_orig, t_orig_prev, model_output):
|
||||
return {"prev_sample": prev_sample}
|
||||
|
||||
def _get_prev_sample(self, sample, timestep, timestep_prev, model_output):
|
||||
# See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
|
||||
# this function computes x_(t−δ) using the formula of (9)
|
||||
# Note that x_t needs to be added to both sides of the equation
|
||||
@@ -182,8 +194,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
# sample -> x_t
|
||||
# model_output -> e_θ(x_t, t)
|
||||
# prev_sample -> x_(t−δ)
|
||||
alpha_prod_t = self.alphas_cumprod[t_orig + 1]
|
||||
alpha_prod_t_prev = self.alphas_cumprod[t_orig_prev + 1]
|
||||
alpha_prod_t = self.alphas_cumprod[timestep + 1]
|
||||
alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1]
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import pdb
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
@@ -383,6 +382,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
|
||||
|
||||
def check_over_configs(self, time_step=0, **config):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
sample = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
|
||||
@@ -390,14 +390,14 @@ class PNDMSchedulerTest(SchedulerCommonTest):
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
scheduler_config = self.get_scheduler_config(**config)
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
scheduler.set_timesteps(kwargs["num_inference_steps"])
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
# copy over dummy past residuals
|
||||
scheduler.ets = dummy_past_residuals[:]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_config(tmpdirname)
|
||||
new_scheduler.set_timesteps(kwargs["num_inference_steps"])
|
||||
new_scheduler.set_timesteps(num_inference_steps)
|
||||
# copy over dummy past residuals
|
||||
new_scheduler.ets = dummy_past_residuals[:]
|
||||
|
||||
@@ -416,7 +416,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
|
||||
|
||||
def check_over_forward(self, time_step=0, **forward_kwargs):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
kwargs.update(forward_kwargs)
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
sample = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
|
||||
@@ -424,7 +424,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
scheduler.set_timesteps(kwargs["num_inference_steps"])
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
# copy over dummy past residuals
|
||||
scheduler.ets = dummy_past_residuals[:]
|
||||
@@ -434,7 +434,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
|
||||
new_scheduler = scheduler_class.from_config(tmpdirname)
|
||||
# copy over dummy past residuals
|
||||
new_scheduler.ets = dummy_past_residuals[:]
|
||||
new_scheduler.set_timesteps(kwargs["num_inference_steps"])
|
||||
new_scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
|
||||
new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
|
||||
@@ -474,12 +474,12 @@ class PNDMSchedulerTest(SchedulerCommonTest):
|
||||
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
|
||||
kwargs["num_inference_steps"] = num_inference_steps
|
||||
|
||||
output = scheduler.step_prk(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
|
||||
output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, num_inference_steps, **kwargs)["prev_sample"]
|
||||
output = scheduler.step_prk(residual, 1, sample, **kwargs)["prev_sample"]
|
||||
output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, **kwargs)["prev_sample"]
|
||||
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
|
||||
|
||||
output = scheduler.step_plms(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
|
||||
output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, num_inference_steps, **kwargs)["prev_sample"]
|
||||
output = scheduler.step_plms(residual, 1, sample, **kwargs)["prev_sample"]
|
||||
output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, **kwargs)["prev_sample"]
|
||||
|
||||
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
|
||||
|
||||
@@ -503,14 +503,14 @@ class PNDMSchedulerTest(SchedulerCommonTest):
|
||||
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
|
||||
kwargs["num_inference_steps"] = num_inference_steps
|
||||
|
||||
output_0 = scheduler.step_prk(residual, 0, sample, num_inference_steps, **kwargs)["prev_sample"]
|
||||
output_1 = scheduler.step_prk(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
|
||||
output_0 = scheduler.step_prk(residual, 0, sample, **kwargs)["prev_sample"]
|
||||
output_1 = scheduler.step_prk(residual, 1, sample, **kwargs)["prev_sample"]
|
||||
|
||||
self.assertEqual(output_0.shape, sample.shape)
|
||||
self.assertEqual(output_0.shape, output_1.shape)
|
||||
|
||||
output_0 = scheduler.step_plms(residual, 0, sample, num_inference_steps, **kwargs)["prev_sample"]
|
||||
output_1 = scheduler.step_plms(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
|
||||
output_0 = scheduler.step_plms(residual, 0, sample, **kwargs)["prev_sample"]
|
||||
output_1 = scheduler.step_plms(residual, 1, sample, **kwargs)["prev_sample"]
|
||||
|
||||
self.assertEqual(output_0.shape, sample.shape)
|
||||
self.assertEqual(output_0.shape, output_1.shape)
|
||||
@@ -541,7 +541,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
|
||||
scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample, 50)["prev_sample"]
|
||||
scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample)["prev_sample"]
|
||||
|
||||
def test_full_loop_no_noise(self):
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
@@ -555,11 +555,11 @@ class PNDMSchedulerTest(SchedulerCommonTest):
|
||||
|
||||
for i, t in enumerate(scheduler.prk_timesteps):
|
||||
residual = model(sample, t)
|
||||
sample = scheduler.step_prk(residual, i, sample, num_inference_steps)["prev_sample"]
|
||||
sample = scheduler.step_prk(residual, i, sample)["prev_sample"]
|
||||
|
||||
for i, t in enumerate(scheduler.plms_timesteps):
|
||||
residual = model(sample, t)
|
||||
sample = scheduler.step_plms(residual, i, sample, num_inference_steps)["prev_sample"]
|
||||
sample = scheduler.step_plms(residual, i, sample)["prev_sample"]
|
||||
|
||||
result_sum = torch.sum(torch.abs(sample))
|
||||
result_mean = torch.mean(torch.abs(sample))
|
||||
@@ -706,7 +706,7 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
|
||||
model_output = model(sample, sigma_t)
|
||||
|
||||
output = scheduler.step_pred(model_output, t, sample, **kwargs)
|
||||
sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]
|
||||
sample, _ = output["prev_sample"], output["prev_sample_mean"]
|
||||
|
||||
result_sum = torch.sum(torch.abs(sample))
|
||||
result_mean = torch.mean(torch.abs(sample))
|
||||
|
||||
Reference in New Issue
Block a user