mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
Commit a9109dbb2b (parent 6874d2b57f)
Author: Patrick von Platen
Date: 2022-12-01 13:25:21 +00:00


@@ -2026,6 +2026,8 @@ class KDPM2DiscreteSchedulerTest(SchedulerCommonTest):
             self.check_over_configs(beta_schedule=schedule)
 
     def test_full_loop_no_noise(self):
+        if torch_device == "mps":
+            return
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config()
         scheduler = scheduler_class(**scheduler_config)
@@ -2056,6 +2058,8 @@ class KDPM2DiscreteSchedulerTest(SchedulerCommonTest):
         assert abs(result_mean.item() - 0.0266) < 1e-3
 
     def test_full_loop_device(self):
+        if torch_device == "mps":
+            return
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config()
         scheduler = scheduler_class(**scheduler_config)
@@ -2080,9 +2084,6 @@ class KDPM2DiscreteSchedulerTest(SchedulerCommonTest):
             # The following sum varies between 148 and 156 on mps. Why?
             assert abs(result_sum.item() - 20.4125) < 1e-2
             assert abs(result_mean.item() - 0.0266) < 1e-3
-        elif str(torch_device).startswith("mps"):
-            # Larger tolerance on mps
-            assert abs(result_mean.item() - 0.0266) < 1e-3
         else:
             # CUDA
             assert abs(result_sum.item() - 20.4125) < 1e-2
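
The pattern in the hunks above is consistent: instead of keeping a separate mps branch with a looser tolerance, the KDPM2 tests now return early when the suite runs on Apple's Metal backend. A minimal sketch of that guard, assuming `torch_device` is the test-suite global that resolves to "cuda", "mps", or "cpu" (in the real suite it comes from the diffusers testing utilities, not from the inline detection used here):

    import torch

    # Hypothetical stand-in for the diffusers testing helper.
    torch_device = (
        "cuda" if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available()
        else "cpu"
    )

    def test_full_loop_no_noise():
        if torch_device == "mps":
            # The k-DPM2 sums drift on mps (see the "varies between 148 and 156"
            # comment above), so the test bails out before doing any work.
            return
        # ... the full denoising loop runs only on cpu/cuda ...
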
@@ -2117,17 +2118,15 @@ class KDPM2AncestralDiscreteSchedulerTest(SchedulerCommonTest):
             self.check_over_configs(beta_schedule=schedule)
 
     def test_full_loop_no_noise(self):
+        if torch_device == "mps":
+            return
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config()
         scheduler = scheduler_class(**scheduler_config)
 
         scheduler.set_timesteps(self.num_inference_steps)
 
-        if torch_device == "mps":
-            # device type MPS is not supported for torch.Generator() api.
-            generator = torch.manual_seed(0)
-        else:
-            generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
 
         model = self.dummy_model()
         sample = self.dummy_sample_deter * scheduler.init_noise_sigma
@@ -2153,6 +2152,8 @@ class KDPM2AncestralDiscreteSchedulerTest(SchedulerCommonTest):
         assert abs(result_mean.item() - 18.1159) < 5e-3
 
     def test_full_loop_device(self):
+        if torch_device == "mps":
+            return
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config()
         scheduler = scheduler_class(**scheduler_config)
@@ -2182,9 +2183,6 @@ class KDPM2AncestralDiscreteSchedulerTest(SchedulerCommonTest):
         if str(torch_device).startswith("cpu"):
             assert abs(result_sum.item() - 13849.3945) < 1e-2
             assert abs(result_mean.item() - 18.0331) < 5e-3
-        elif str(torch_device).startswith("mps"):
-            # Larger tolerance on mps
-            assert abs(result_mean.item() - 18.0331) < 1e-2
         else:
             # CUDA
             assert abs(result_sum.item() - 13913.0459) < 1e-2
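
With the early return in place, the seeding code no longer needs the workaround that the removed branch carried ("device type MPS is not supported for torch.Generator() api."): by the time the generator is built, `torch_device` can only be "cpu" or "cuda", so the device-bound form is always valid. A minimal sketch of the two seeding styles the hunk collapses, with `make_generator` as a purely illustrative helper that is not part of diffusers:

    import torch

    def make_generator(device: str, seed: int = 0) -> torch.Generator:
        # Illustrative helper, not a diffusers API.
        if device == "mps":
            # Workaround from the removed branch: its comment says torch.Generator()
            # does not accept the mps device, so the global CPU generator is seeded.
            return torch.manual_seed(seed)
        # Single code path kept by this commit: a generator bound to the device.
        return torch.Generator(device=device).manual_seed(seed)

    generator = make_generator("cpu")
    noise = torch.randn(4, 3, 8, 8, generator=generator)  # reproducible with seed 0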