mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Stochastic Sampler][Slow Test]: Cuda test fixes (#3257)
[Slow Test]: Cuda test fixes Co-authored-by: njindal <njindal@adobe.com>
This commit is contained in:
@@ -65,6 +65,9 @@ class DPMSolverSDESchedulerTest(SchedulerCommonTest):
|
||||
if torch_device in ["mps"]:
|
||||
assert abs(result_sum.item() - 167.47821044921875) < 1e-2
|
||||
assert abs(result_mean.item() - 0.2178705964565277) < 1e-3
|
||||
elif torch_device in ["cuda"]:
|
||||
assert abs(result_sum.item() - 171.59352111816406) < 1e-2
|
||||
assert abs(result_mean.item() - 0.22342906892299652) < 1e-3
|
||||
else:
|
||||
assert abs(result_sum.item() - 162.52383422851562) < 1e-2
|
||||
assert abs(result_mean.item() - 0.211619570851326) < 1e-3
|
||||
@@ -94,6 +97,9 @@ class DPMSolverSDESchedulerTest(SchedulerCommonTest):
|
||||
if torch_device in ["mps"]:
|
||||
assert abs(result_sum.item() - 124.77149200439453) < 1e-2
|
||||
assert abs(result_mean.item() - 0.16226289014816284) < 1e-3
|
||||
elif torch_device in ["cuda"]:
|
||||
assert abs(result_sum.item() - 128.1663360595703) < 1e-2
|
||||
assert abs(result_mean.item() - 0.16688326001167297) < 1e-3
|
||||
else:
|
||||
assert abs(result_sum.item() - 119.8487548828125) < 1e-2
|
||||
assert abs(result_mean.item() - 0.1560530662536621) < 1e-3
|
||||
@@ -122,6 +128,9 @@ class DPMSolverSDESchedulerTest(SchedulerCommonTest):
|
||||
if torch_device in ["mps"]:
|
||||
assert abs(result_sum.item() - 167.46957397460938) < 1e-2
|
||||
assert abs(result_mean.item() - 0.21805934607982635) < 1e-3
|
||||
elif torch_device in ["cuda"]:
|
||||
assert abs(result_sum.item() - 171.59353637695312) < 1e-2
|
||||
assert abs(result_mean.item() - 0.22342908382415771) < 1e-3
|
||||
else:
|
||||
assert abs(result_sum.item() - 162.52383422851562) < 1e-2
|
||||
assert abs(result_mean.item() - 0.211619570851326) < 1e-3
|
||||
@@ -151,6 +160,9 @@ class DPMSolverSDESchedulerTest(SchedulerCommonTest):
|
||||
if torch_device in ["mps"]:
|
||||
assert abs(result_sum.item() - 176.66974135742188) < 1e-2
|
||||
assert abs(result_mean.item() - 0.23003872730981811) < 1e-2
|
||||
elif torch_device in ["cuda"]:
|
||||
assert abs(result_sum.item() - 177.63653564453125) < 1e-2
|
||||
assert abs(result_mean.item() - 0.23003872730981811) < 1e-2
|
||||
else:
|
||||
assert abs(result_sum.item() - 170.3135223388672) < 1e-2
|
||||
assert abs(result_mean.item() - 0.23003872730981811) < 1e-2
|
||||
|
||||
Reference in New Issue
Block a user