diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index c95ea43e55..0bf03f057f 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -576,7 +576,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): if order == 2: rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device) else: - rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]) + rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype) else: D1s = None @@ -714,7 +714,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): if order == 1: rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device) else: - rhos_c = torch.linalg.solve(R, b) + rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype) if self.predict_x0: x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 diff --git a/tests/schedulers/test_scheduler_unipc.py b/tests/schedulers/test_scheduler_unipc.py index fa34ef75b5..a591a60105 100644 --- a/tests/schedulers/test_scheduler_unipc.py +++ b/tests/schedulers/test_scheduler_unipc.py @@ -229,20 +229,29 @@ class UniPCMultistepSchedulerTest(SchedulerCommonTest): assert abs(result_mean.item() - 0.1966) < 1e-3 def test_fp16_support(self): - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0) - scheduler = scheduler_class(**scheduler_config) + for order in [1, 2, 3]: + for solver_type in ["bh1", "bh2"]: + for prediction_type in ["epsilon", "sample", "v_prediction"]: + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config( + thresholding=True, + dynamic_thresholding_ratio=0, + prediction_type=prediction_type, + solver_order=order, + solver_type=solver_type, + ) + scheduler = scheduler_class(**scheduler_config) - num_inference_steps = 10 - model = self.dummy_model() - sample = self.dummy_sample_deter.half() - scheduler.set_timesteps(num_inference_steps) + num_inference_steps = 10 + model = self.dummy_model() + sample = self.dummy_sample_deter.half() + scheduler.set_timesteps(num_inference_steps) - for i, t in enumerate(scheduler.timesteps): - residual = model(sample, t) - sample = scheduler.step(residual, t, sample).prev_sample + for i, t in enumerate(scheduler.timesteps): + residual = model(sample, t) + sample = scheduler.step(residual, t, sample).prev_sample - assert sample.dtype == torch.float16 + assert sample.dtype == torch.float16 def test_full_loop_with_noise(self): scheduler_class = self.scheduler_classes[0]