From c8b0f0eb21354df3b920e75fc1517d4d06d757b4 Mon Sep 17 00:00:00 2001 From: Leng Yue Date: Mon, 2 Oct 2023 10:17:46 -0700 Subject: [PATCH] Update UniPC to support 1D diffusion. (#5199) * Update Unipc einsum to support 1D and 3D diffusion. * Add unittest * Update unittest & edge case * Fix unittest * Fix testing_utils.py * Fix unittest file --------- Co-authored-by: Patrick von Platen --- .../schedulers/scheduling_unipc_multistep.py | 14 +-- tests/schedulers/test_scheduler_unipc.py | 110 ++++++++++++++++++ 2 files changed, 117 insertions(+), 7 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index 2b5bd4fd60..d61341cee7 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -282,13 +282,13 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): https://arxiv.org/abs/2205.11487 """ dtype = sample.dtype - batch_size, channels, height, width = sample.shape + batch_size, channels, *remaining_dims = sample.shape if dtype not in (torch.float32, torch.float64): sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half # Flatten sample for doing quantile calculation along each image - sample = sample.reshape(batch_size, channels * height * width) + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) abs_sample = sample.abs() # "a certain percentile absolute pixel value" @@ -300,7 +300,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" - sample = sample.reshape(batch_size, channels, height, width) + sample = sample.reshape(batch_size, channels, *remaining_dims) sample = sample.to(dtype) return sample @@ -534,14 +534,14 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): if self.predict_x0: x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 if D1s is not None: - pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s) + pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) else: pred_res = 0 x_t = x_t_ - alpha_t * B_h * pred_res else: x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 if D1s is not None: - pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s) + pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) else: pred_res = 0 x_t = x_t_ - sigma_t * B_h * pred_res @@ -670,7 +670,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): if self.predict_x0: x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 if D1s is not None: - corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s) + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) else: corr_res = 0 D1_t = model_t - m0 @@ -678,7 +678,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): else: x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 if D1s is not None: - corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s) + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) else: corr_res = 0 D1_t = model_t - m0 diff --git a/tests/schedulers/test_scheduler_unipc.py b/tests/schedulers/test_scheduler_unipc.py index 1b9a464ba6..be41cea95b 100644 --- a/tests/schedulers/test_scheduler_unipc.py +++ b/tests/schedulers/test_scheduler_unipc.py @@ -269,3 +269,113 @@ class UniPCMultistepSchedulerTest(SchedulerCommonTest): assert abs(result_sum.item() - 315.5757) < 1e-2, f" expected result sum 315.5757, but get {result_sum}" assert abs(result_mean.item() - 0.4109) < 1e-3, f" expected result mean 0.4109, but get {result_mean}" + + +class UniPCMultistepScheduler1DTest(UniPCMultistepSchedulerTest): + @property + def dummy_sample(self): + batch_size = 4 + num_channels = 3 + width = 8 + + sample = torch.rand((batch_size, num_channels, width)) + + return sample + + @property + def dummy_noise_deter(self): + batch_size = 4 + num_channels = 3 + width = 8 + + num_elems = batch_size * num_channels * width + sample = torch.arange(num_elems).flip(-1) + sample = sample.reshape(num_channels, width, batch_size) + sample = sample / num_elems + sample = sample.permute(2, 0, 1) + + return sample + + @property + def dummy_sample_deter(self): + batch_size = 4 + num_channels = 3 + width = 8 + + num_elems = batch_size * num_channels * width + sample = torch.arange(num_elems) + sample = sample.reshape(num_channels, width, batch_size) + sample = sample / num_elems + sample = sample.permute(2, 0, 1) + + return sample + + def test_switch(self): + # make sure that iterating over schedulers with same config names gives same results + # for defaults + scheduler = UniPCMultistepScheduler(**self.get_scheduler_config()) + sample = self.full_loop(scheduler=scheduler) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_mean.item() - 0.2441) < 1e-3 + + scheduler = DPMSolverSinglestepScheduler.from_config(scheduler.config) + scheduler = DEISMultistepScheduler.from_config(scheduler.config) + scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config) + scheduler = UniPCMultistepScheduler.from_config(scheduler.config) + + sample = self.full_loop(scheduler=scheduler) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_mean.item() - 0.2441) < 1e-3 + + def test_full_loop_no_noise(self): + sample = self.full_loop() + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_mean.item() - 0.2441) < 1e-3 + + def test_full_loop_with_karras(self): + sample = self.full_loop(use_karras_sigmas=True) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_mean.item() - 0.2898) < 1e-3 + + def test_full_loop_with_v_prediction(self): + sample = self.full_loop(prediction_type="v_prediction") + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_mean.item() - 0.1014) < 1e-3 + + def test_full_loop_with_karras_and_v_prediction(self): + sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_mean.item() - 0.1944) < 1e-3 + + def test_full_loop_with_noise(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + num_inference_steps = 10 + t_start = 8 + + model = self.dummy_model() + sample = self.dummy_sample_deter + scheduler.set_timesteps(num_inference_steps) + + # add noise + noise = self.dummy_noise_deter + timesteps = scheduler.timesteps[t_start * scheduler.order :] + sample = scheduler.add_noise(sample, noise, timesteps[:1]) + + for i, t in enumerate(timesteps): + residual = model(sample, t) + sample = scheduler.step(residual, t, sample).prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_sum.item() - 39.0870) < 1e-2, f" expected result sum 39.0870, but get {result_sum}" + assert abs(result_mean.item() - 0.4072) < 1e-3, f" expected result mean 0.4072, but get {result_mean}"