1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Stochastic Sampler][Slow Test]: Cuda test fixes (#3257)

[Slow Test]: Cuda test fixes

Co-authored-by: njindal <njindal@adobe.com>
This commit is contained in:
Nipun Jindal
2023-04-27 14:52:38 +05:30
committed by Daniel Gu
parent 7880ed77fb
commit 8def721ec8

View File

@@ -65,6 +65,9 @@ class DPMSolverSDESchedulerTest(SchedulerCommonTest):
if torch_device in ["mps"]:
assert abs(result_sum.item() - 167.47821044921875) < 1e-2
assert abs(result_mean.item() - 0.2178705964565277) < 1e-3
elif torch_device in ["cuda"]:
assert abs(result_sum.item() - 171.59352111816406) < 1e-2
assert abs(result_mean.item() - 0.22342906892299652) < 1e-3
else:
assert abs(result_sum.item() - 162.52383422851562) < 1e-2
assert abs(result_mean.item() - 0.211619570851326) < 1e-3
@@ -94,6 +97,9 @@ class DPMSolverSDESchedulerTest(SchedulerCommonTest):
if torch_device in ["mps"]:
assert abs(result_sum.item() - 124.77149200439453) < 1e-2
assert abs(result_mean.item() - 0.16226289014816284) < 1e-3
elif torch_device in ["cuda"]:
assert abs(result_sum.item() - 128.1663360595703) < 1e-2
assert abs(result_mean.item() - 0.16688326001167297) < 1e-3
else:
assert abs(result_sum.item() - 119.8487548828125) < 1e-2
assert abs(result_mean.item() - 0.1560530662536621) < 1e-3
@@ -122,6 +128,9 @@ class DPMSolverSDESchedulerTest(SchedulerCommonTest):
if torch_device in ["mps"]:
assert abs(result_sum.item() - 167.46957397460938) < 1e-2
assert abs(result_mean.item() - 0.21805934607982635) < 1e-3
elif torch_device in ["cuda"]:
assert abs(result_sum.item() - 171.59353637695312) < 1e-2
assert abs(result_mean.item() - 0.22342908382415771) < 1e-3
else:
assert abs(result_sum.item() - 162.52383422851562) < 1e-2
assert abs(result_mean.item() - 0.211619570851326) < 1e-3
@@ -151,6 +160,9 @@ class DPMSolverSDESchedulerTest(SchedulerCommonTest):
if torch_device in ["mps"]:
assert abs(result_sum.item() - 176.66974135742188) < 1e-2
assert abs(result_mean.item() - 0.23003872730981811) < 1e-2
elif torch_device in ["cuda"]:
assert abs(result_sum.item() - 177.63653564453125) < 1e-2
assert abs(result_mean.item() - 0.23003872730981811) < 1e-2
else:
assert abs(result_sum.item() - 170.3135223388672) < 1e-2
assert abs(result_mean.item() - 0.23003872730981811) < 1e-2