mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
UniPC Multistep fix tensor dtype/device on order=3 (#7532)
* UniPC UTs iterate solvers on FP16 It wasn't catching errs on order==3. Might be excessive? * UniPC Multistep fix tensor dtype/device on order=3 * UniPC UTs Add v_pred to fp16 test iter For completions sake. Probably overkill?
This commit is contained in:
@@ -576,7 +576,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
if order == 2:
|
||||
rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
|
||||
else:
|
||||
rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
|
||||
rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype)
|
||||
else:
|
||||
D1s = None
|
||||
|
||||
@@ -714,7 +714,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
if order == 1:
|
||||
rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
|
||||
else:
|
||||
rhos_c = torch.linalg.solve(R, b)
|
||||
rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)
|
||||
|
||||
if self.predict_x0:
|
||||
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
|
||||
|
||||
@@ -229,20 +229,29 @@ class UniPCMultistepSchedulerTest(SchedulerCommonTest):
|
||||
assert abs(result_mean.item() - 0.1966) < 1e-3
|
||||
|
||||
def test_fp16_support(self):
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0)
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
for order in [1, 2, 3]:
|
||||
for solver_type in ["bh1", "bh2"]:
|
||||
for prediction_type in ["epsilon", "sample", "v_prediction"]:
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
scheduler_config = self.get_scheduler_config(
|
||||
thresholding=True,
|
||||
dynamic_thresholding_ratio=0,
|
||||
prediction_type=prediction_type,
|
||||
solver_order=order,
|
||||
solver_type=solver_type,
|
||||
)
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
|
||||
num_inference_steps = 10
|
||||
model = self.dummy_model()
|
||||
sample = self.dummy_sample_deter.half()
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
num_inference_steps = 10
|
||||
model = self.dummy_model()
|
||||
sample = self.dummy_sample_deter.half()
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
for i, t in enumerate(scheduler.timesteps):
|
||||
residual = model(sample, t)
|
||||
sample = scheduler.step(residual, t, sample).prev_sample
|
||||
for i, t in enumerate(scheduler.timesteps):
|
||||
residual = model(sample, t)
|
||||
sample = scheduler.step(residual, t, sample).prev_sample
|
||||
|
||||
assert sample.dtype == torch.float16
|
||||
assert sample.dtype == torch.float16
|
||||
|
||||
def test_full_loop_with_noise(self):
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
|
||||
Reference in New Issue
Block a user