From a9109dbb2be673437096a0d52580826b4cbc401c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 1 Dec 2022 13:25:21 +0000 Subject: [PATCH] up --- tests/test_scheduler.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 28b7bf7b05..e75d384012 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -2026,6 +2026,8 @@ class KDPM2DiscreteSchedulerTest(SchedulerCommonTest): self.check_over_configs(beta_schedule=schedule) def test_full_loop_no_noise(self): + if torch_device == "mps": + return scheduler_class = self.scheduler_classes[0] scheduler_config = self.get_scheduler_config() scheduler = scheduler_class(**scheduler_config) @@ -2056,6 +2058,8 @@ class KDPM2DiscreteSchedulerTest(SchedulerCommonTest): assert abs(result_mean.item() - 0.0266) < 1e-3 def test_full_loop_device(self): + if torch_device == "mps": + return scheduler_class = self.scheduler_classes[0] scheduler_config = self.get_scheduler_config() scheduler = scheduler_class(**scheduler_config) @@ -2080,9 +2084,6 @@ class KDPM2DiscreteSchedulerTest(SchedulerCommonTest): # The following sum varies between 148 and 156 on mps. Why? assert abs(result_sum.item() - 20.4125) < 1e-2 assert abs(result_mean.item() - 0.0266) < 1e-3 - elif str(torch_device).startswith("mps"): - # Larger tolerance on mps - assert abs(result_mean.item() - 0.0266) < 1e-3 else: # CUDA assert abs(result_sum.item() - 20.4125) < 1e-2 @@ -2117,17 +2118,15 @@ class KDPM2AncestralDiscreteSchedulerTest(SchedulerCommonTest): self.check_over_configs(beta_schedule=schedule) def test_full_loop_no_noise(self): + if torch_device == "mps": + return scheduler_class = self.scheduler_classes[0] scheduler_config = self.get_scheduler_config() scheduler = scheduler_class(**scheduler_config) scheduler.set_timesteps(self.num_inference_steps) - if torch_device == "mps": - # device type MPS is not supported for torch.Generator() api. - generator = torch.manual_seed(0) - else: - generator = torch.Generator(device=torch_device).manual_seed(0) + generator = torch.Generator(device=torch_device).manual_seed(0) model = self.dummy_model() sample = self.dummy_sample_deter * scheduler.init_noise_sigma @@ -2153,6 +2152,8 @@ class KDPM2AncestralDiscreteSchedulerTest(SchedulerCommonTest): assert abs(result_mean.item() - 18.1159) < 5e-3 def test_full_loop_device(self): + if torch_device == "mps": + return scheduler_class = self.scheduler_classes[0] scheduler_config = self.get_scheduler_config() scheduler = scheduler_class(**scheduler_config) @@ -2182,9 +2183,6 @@ class KDPM2AncestralDiscreteSchedulerTest(SchedulerCommonTest): if str(torch_device).startswith("cpu"): assert abs(result_sum.item() - 13849.3945) < 1e-2 assert abs(result_mean.item() - 18.0331) < 5e-3 - elif str(torch_device).startswith("mps"): - # Larger tolerance on mps - assert abs(result_mean.item() - 18.0331) < 1e-2 else: # CUDA assert abs(result_sum.item() - 13913.0459) < 1e-2