Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-29 07:22:12 +03:00
@@ -2026,6 +2026,8 @@ class KDPM2DiscreteSchedulerTest(SchedulerCommonTest):
             self.check_over_configs(beta_schedule=schedule)
 
     def test_full_loop_no_noise(self):
+        if torch_device == "mps":
+            return
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config()
         scheduler = scheduler_class(**scheduler_config)
@@ -2056,6 +2058,8 @@ class KDPM2DiscreteSchedulerTest(SchedulerCommonTest):
             assert abs(result_mean.item() - 0.0266) < 1e-3
 
     def test_full_loop_device(self):
+        if torch_device == "mps":
+            return
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config()
         scheduler = scheduler_class(**scheduler_config)
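The two hunks above add the same guard to both KDPM2Discrete full-loop tests: on the mps backend the test body is skipped entirely via an early return. Below is a minimal sketch, not part of the commit, of an equivalent way to express that skip with a unittest decorator; the real suite takes torch_device from its shared testing utilities, so the local derivation here is an assumption made only to keep the sketch self-contained.

    # Sketch only, not part of the commit: the early-return guard expressed as a
    # unittest skip. "torch_device" is a local stand-in for the string exported
    # by the diffusers test utilities.
    import unittest

    import torch

    torch_device = (
        "mps"
        if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available()
        else ("cuda" if torch.cuda.is_available() else "cpu")
    )


    class KDPM2SkipSketch(unittest.TestCase):
        @unittest.skipIf(torch_device == "mps", "KDPM2 full-loop tests do not run on mps")
        def test_full_loop_no_noise(self):
            # On cpu/cuda the body runs; on mps the test is reported as skipped,
            # which has the same effect as the `if torch_device == "mps": return`
            # guard added in the diff.
            self.assertTrue(True)

Either spelling keeps the scheduler math off the mps backend while leaving cpu and CUDA coverage untouched.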
@@ -2080,9 +2084,6 @@ class KDPM2DiscreteSchedulerTest(SchedulerCommonTest):
             # The following sum varies between 148 and 156 on mps. Why?
             assert abs(result_sum.item() - 20.4125) < 1e-2
             assert abs(result_mean.item() - 0.0266) < 1e-3
-        elif str(torch_device).startswith("mps"):
-            # Larger tolerance on mps
-            assert abs(result_mean.item() - 0.0266) < 1e-3
         else:
             # CUDA
             assert abs(result_sum.item() - 20.4125) < 1e-2
@@ -2117,17 +2118,15 @@ class KDPM2AncestralDiscreteSchedulerTest(SchedulerCommonTest):
             self.check_over_configs(beta_schedule=schedule)
 
     def test_full_loop_no_noise(self):
+        if torch_device == "mps":
+            return
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config()
         scheduler = scheduler_class(**scheduler_config)
 
         scheduler.set_timesteps(self.num_inference_steps)
 
-        if torch_device == "mps":
-            # device type MPS is not supported for torch.Generator() api.
-            generator = torch.manual_seed(0)
-        else:
-            generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
 
         model = self.dummy_model()
         sample = self.dummy_sample_deter * scheduler.init_noise_sigma
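The removed comment in the hunk above explains the old branching: torch.Generator() did not accept the mps device type, so the test fell back to torch.manual_seed(0) on that backend. Once the test returns early on mps, a single device-bound generator is enough, which is the line the commit keeps. A small, self-contained sketch of that seeding pattern follows; the device selection is a placeholder, not taken from the commit.

    # Sketch: reproducible noise via a generator bound to the target device.
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"  # placeholder selection

    # Mirrors the kept line: torch.Generator(device=torch_device).manual_seed(0)
    generator = torch.Generator(device=device).manual_seed(0)
    noise_a = torch.randn(4, 3, 8, 8, generator=generator, device=device)

    # Re-seeding an identically configured generator reproduces the same noise.
    generator = torch.Generator(device=device).manual_seed(0)
    noise_b = torch.randn(4, 3, 8, 8, generator=generator, device=device)
    assert torch.equal(noise_a, noise_b)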
@@ -2153,6 +2152,8 @@ class KDPM2AncestralDiscreteSchedulerTest(SchedulerCommonTest):
             assert abs(result_mean.item() - 18.1159) < 5e-3
 
     def test_full_loop_device(self):
+        if torch_device == "mps":
+            return
         scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config()
         scheduler = scheduler_class(**scheduler_config)
@@ -2182,9 +2183,6 @@ class KDPM2AncestralDiscreteSchedulerTest(SchedulerCommonTest):
         if str(torch_device).startswith("cpu"):
             assert abs(result_sum.item() - 13849.3945) < 1e-2
             assert abs(result_mean.item() - 18.0331) < 5e-3
-        elif str(torch_device).startswith("mps"):
-            # Larger tolerance on mps
-            assert abs(result_mean.item() - 18.0331) < 1e-2
         else:
             # CUDA
             assert abs(result_sum.item() - 13913.0459) < 1e-2
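Taken together, the commit trades per-device branching for an up-front skip: because test_full_loop_no_noise and test_full_loop_device now return immediately on mps, the elif str(torch_device).startswith("mps") tolerance branches removed above are unreachable, and only the cpu and CUDA expectations remain. A rough sketch of the resulting assertion shape is below; the numeric expectations are copied from the diff, while the helper name and signature are hypothetical.

    # Hypothetical helper illustrating the post-change structure of the
    # ancestral test's final checks; result_sum/result_mean stand in for the
    # tensors the test body computes.
    def check_ancestral_loop_expectations(result_sum, result_mean, torch_device):
        if str(torch_device).startswith("cpu"):
            assert abs(result_sum.item() - 13849.3945) < 1e-2
            assert abs(result_mean.item() - 18.0331) < 5e-3
        else:
            # CUDA only: mps never reaches this point because the test
            # returns before running the loop.
            assert abs(result_sum.item() - 13913.0459) < 1e-2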