1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Update UniPC to support 1D diffusion. (#5199)

* Update Unipc einsum to support 1D and 3D diffusion.

* Add unittest

* Update unittest & edge case

* Fix unittest

* Fix testing_utils.py

* Fix unittest file

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Leng Yue
2023-10-02 10:17:46 -07:00
committed by GitHub
parent 7a4324cce3
commit c8b0f0eb21
2 changed files with 117 additions and 7 deletions

View File

@@ -282,13 +282,13 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
https://arxiv.org/abs/2205.11487
"""
dtype = sample.dtype
batch_size, channels, height, width = sample.shape
batch_size, channels, *remaining_dims = sample.shape
if dtype not in (torch.float32, torch.float64):
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
# Flatten sample for doing quantile calculation along each image
sample = sample.reshape(batch_size, channels * height * width)
sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
@@ -300,7 +300,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
sample = sample.reshape(batch_size, channels, height, width)
sample = sample.reshape(batch_size, channels, *remaining_dims)
sample = sample.to(dtype)
return sample
@@ -534,14 +534,14 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
if self.predict_x0:
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
if D1s is not None:
pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s)
pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)
else:
pred_res = 0
x_t = x_t_ - alpha_t * B_h * pred_res
else:
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
if D1s is not None:
pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s)
pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)
else:
pred_res = 0
x_t = x_t_ - sigma_t * B_h * pred_res
@@ -670,7 +670,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
if self.predict_x0:
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
if D1s is not None:
corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s)
corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
else:
corr_res = 0
D1_t = model_t - m0
@@ -678,7 +678,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
else:
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
if D1s is not None:
corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s)
corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
else:
corr_res = 0
D1_t = model_t - m0

View File

@@ -269,3 +269,113 @@ class UniPCMultistepSchedulerTest(SchedulerCommonTest):
assert abs(result_sum.item() - 315.5757) < 1e-2, f" expected result sum 315.5757, but get {result_sum}"
assert abs(result_mean.item() - 0.4109) < 1e-3, f" expected result mean 0.4109, but get {result_mean}"
class UniPCMultistepScheduler1DTest(UniPCMultistepSchedulerTest):
@property
def dummy_sample(self):
batch_size = 4
num_channels = 3
width = 8
sample = torch.rand((batch_size, num_channels, width))
return sample
@property
def dummy_noise_deter(self):
batch_size = 4
num_channels = 3
width = 8
num_elems = batch_size * num_channels * width
sample = torch.arange(num_elems).flip(-1)
sample = sample.reshape(num_channels, width, batch_size)
sample = sample / num_elems
sample = sample.permute(2, 0, 1)
return sample
@property
def dummy_sample_deter(self):
batch_size = 4
num_channels = 3
width = 8
num_elems = batch_size * num_channels * width
sample = torch.arange(num_elems)
sample = sample.reshape(num_channels, width, batch_size)
sample = sample / num_elems
sample = sample.permute(2, 0, 1)
return sample
def test_switch(self):
# make sure that iterating over schedulers with same config names gives same results
# for defaults
scheduler = UniPCMultistepScheduler(**self.get_scheduler_config())
sample = self.full_loop(scheduler=scheduler)
result_mean = torch.mean(torch.abs(sample))
assert abs(result_mean.item() - 0.2441) < 1e-3
scheduler = DPMSolverSinglestepScheduler.from_config(scheduler.config)
scheduler = DEISMultistepScheduler.from_config(scheduler.config)
scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
sample = self.full_loop(scheduler=scheduler)
result_mean = torch.mean(torch.abs(sample))
assert abs(result_mean.item() - 0.2441) < 1e-3
def test_full_loop_no_noise(self):
sample = self.full_loop()
result_mean = torch.mean(torch.abs(sample))
assert abs(result_mean.item() - 0.2441) < 1e-3
def test_full_loop_with_karras(self):
sample = self.full_loop(use_karras_sigmas=True)
result_mean = torch.mean(torch.abs(sample))
assert abs(result_mean.item() - 0.2898) < 1e-3
def test_full_loop_with_v_prediction(self):
sample = self.full_loop(prediction_type="v_prediction")
result_mean = torch.mean(torch.abs(sample))
assert abs(result_mean.item() - 0.1014) < 1e-3
def test_full_loop_with_karras_and_v_prediction(self):
sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True)
result_mean = torch.mean(torch.abs(sample))
assert abs(result_mean.item() - 0.1944) < 1e-3
def test_full_loop_with_noise(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
num_inference_steps = 10
t_start = 8
model = self.dummy_model()
sample = self.dummy_sample_deter
scheduler.set_timesteps(num_inference_steps)
# add noise
noise = self.dummy_noise_deter
timesteps = scheduler.timesteps[t_start * scheduler.order :]
sample = scheduler.add_noise(sample, noise, timesteps[:1])
for i, t in enumerate(timesteps):
residual = model(sample, t)
sample = scheduler.step(residual, t, sample).prev_sample
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 39.0870) < 1e-2, f" expected result sum 39.0870, but get {result_sum}"
assert abs(result_mean.item() - 0.4072) < 1e-3, f" expected result mean 0.4072, but get {result_mean}"