mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Update UniPC to support 1D diffusion. (#5199)
* Update Unipc einsum to support 1D and 3D diffusion. * Add unittest * Update unittest & edge case * Fix unittest * Fix testing_utils.py * Fix unittest file --------- Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -282,13 +282,13 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
https://arxiv.org/abs/2205.11487
|
||||
"""
|
||||
dtype = sample.dtype
|
||||
batch_size, channels, height, width = sample.shape
|
||||
batch_size, channels, *remaining_dims = sample.shape
|
||||
|
||||
if dtype not in (torch.float32, torch.float64):
|
||||
sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
|
||||
|
||||
# Flatten sample for doing quantile calculation along each image
|
||||
sample = sample.reshape(batch_size, channels * height * width)
|
||||
sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
|
||||
|
||||
abs_sample = sample.abs() # "a certain percentile absolute pixel value"
|
||||
|
||||
@@ -300,7 +300,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
|
||||
sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
|
||||
|
||||
sample = sample.reshape(batch_size, channels, height, width)
|
||||
sample = sample.reshape(batch_size, channels, *remaining_dims)
|
||||
sample = sample.to(dtype)
|
||||
|
||||
return sample
|
||||
@@ -534,14 +534,14 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
if self.predict_x0:
|
||||
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
|
||||
if D1s is not None:
|
||||
pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s)
|
||||
pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)
|
||||
else:
|
||||
pred_res = 0
|
||||
x_t = x_t_ - alpha_t * B_h * pred_res
|
||||
else:
|
||||
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
|
||||
if D1s is not None:
|
||||
pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s)
|
||||
pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)
|
||||
else:
|
||||
pred_res = 0
|
||||
x_t = x_t_ - sigma_t * B_h * pred_res
|
||||
@@ -670,7 +670,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
if self.predict_x0:
|
||||
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
|
||||
if D1s is not None:
|
||||
corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s)
|
||||
corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
|
||||
else:
|
||||
corr_res = 0
|
||||
D1_t = model_t - m0
|
||||
@@ -678,7 +678,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
else:
|
||||
x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
|
||||
if D1s is not None:
|
||||
corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s)
|
||||
corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
|
||||
else:
|
||||
corr_res = 0
|
||||
D1_t = model_t - m0
|
||||
|
||||
@@ -269,3 +269,113 @@ class UniPCMultistepSchedulerTest(SchedulerCommonTest):
|
||||
|
||||
assert abs(result_sum.item() - 315.5757) < 1e-2, f" expected result sum 315.5757, but get {result_sum}"
|
||||
assert abs(result_mean.item() - 0.4109) < 1e-3, f" expected result mean 0.4109, but get {result_mean}"
|
||||
|
||||
|
||||
class UniPCMultistepScheduler1DTest(UniPCMultistepSchedulerTest):
|
||||
@property
|
||||
def dummy_sample(self):
|
||||
batch_size = 4
|
||||
num_channels = 3
|
||||
width = 8
|
||||
|
||||
sample = torch.rand((batch_size, num_channels, width))
|
||||
|
||||
return sample
|
||||
|
||||
@property
|
||||
def dummy_noise_deter(self):
|
||||
batch_size = 4
|
||||
num_channels = 3
|
||||
width = 8
|
||||
|
||||
num_elems = batch_size * num_channels * width
|
||||
sample = torch.arange(num_elems).flip(-1)
|
||||
sample = sample.reshape(num_channels, width, batch_size)
|
||||
sample = sample / num_elems
|
||||
sample = sample.permute(2, 0, 1)
|
||||
|
||||
return sample
|
||||
|
||||
@property
|
||||
def dummy_sample_deter(self):
|
||||
batch_size = 4
|
||||
num_channels = 3
|
||||
width = 8
|
||||
|
||||
num_elems = batch_size * num_channels * width
|
||||
sample = torch.arange(num_elems)
|
||||
sample = sample.reshape(num_channels, width, batch_size)
|
||||
sample = sample / num_elems
|
||||
sample = sample.permute(2, 0, 1)
|
||||
|
||||
return sample
|
||||
|
||||
def test_switch(self):
|
||||
# make sure that iterating over schedulers with same config names gives same results
|
||||
# for defaults
|
||||
scheduler = UniPCMultistepScheduler(**self.get_scheduler_config())
|
||||
sample = self.full_loop(scheduler=scheduler)
|
||||
result_mean = torch.mean(torch.abs(sample))
|
||||
|
||||
assert abs(result_mean.item() - 0.2441) < 1e-3
|
||||
|
||||
scheduler = DPMSolverSinglestepScheduler.from_config(scheduler.config)
|
||||
scheduler = DEISMultistepScheduler.from_config(scheduler.config)
|
||||
scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
|
||||
scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
|
||||
|
||||
sample = self.full_loop(scheduler=scheduler)
|
||||
result_mean = torch.mean(torch.abs(sample))
|
||||
|
||||
assert abs(result_mean.item() - 0.2441) < 1e-3
|
||||
|
||||
def test_full_loop_no_noise(self):
|
||||
sample = self.full_loop()
|
||||
result_mean = torch.mean(torch.abs(sample))
|
||||
|
||||
assert abs(result_mean.item() - 0.2441) < 1e-3
|
||||
|
||||
def test_full_loop_with_karras(self):
|
||||
sample = self.full_loop(use_karras_sigmas=True)
|
||||
result_mean = torch.mean(torch.abs(sample))
|
||||
|
||||
assert abs(result_mean.item() - 0.2898) < 1e-3
|
||||
|
||||
def test_full_loop_with_v_prediction(self):
|
||||
sample = self.full_loop(prediction_type="v_prediction")
|
||||
result_mean = torch.mean(torch.abs(sample))
|
||||
|
||||
assert abs(result_mean.item() - 0.1014) < 1e-3
|
||||
|
||||
def test_full_loop_with_karras_and_v_prediction(self):
|
||||
sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True)
|
||||
result_mean = torch.mean(torch.abs(sample))
|
||||
|
||||
assert abs(result_mean.item() - 0.1944) < 1e-3
|
||||
|
||||
def test_full_loop_with_noise(self):
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
|
||||
num_inference_steps = 10
|
||||
t_start = 8
|
||||
|
||||
model = self.dummy_model()
|
||||
sample = self.dummy_sample_deter
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
# add noise
|
||||
noise = self.dummy_noise_deter
|
||||
timesteps = scheduler.timesteps[t_start * scheduler.order :]
|
||||
sample = scheduler.add_noise(sample, noise, timesteps[:1])
|
||||
|
||||
for i, t in enumerate(timesteps):
|
||||
residual = model(sample, t)
|
||||
sample = scheduler.step(residual, t, sample).prev_sample
|
||||
|
||||
result_sum = torch.sum(torch.abs(sample))
|
||||
result_mean = torch.mean(torch.abs(sample))
|
||||
|
||||
assert abs(result_sum.item() - 39.0870) < 1e-2, f" expected result sum 39.0870, but get {result_sum}"
|
||||
assert abs(result_mean.item() - 0.4072) < 1e-3, f" expected result mean 0.4072, but get {result_mean}"
|
||||
|
||||
Reference in New Issue
Block a user