mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Tests] Fix mps+generator fast tests (#1230)
* [Tests] Fix mps+generator fast tests
* mps for Euler
* retry
* warmup issue again?
* fix reproducible initial noise
* Revert "fix reproducible initial noise"
This reverts commit f300d05cb9.
* fix reproducible initial noise
* fix device
This commit is contained in:
2
.github/workflows/pr_tests.yml
vendored
2
.github/workflows/pr_tests.yml
vendored
@@ -136,7 +136,7 @@ jobs:
|
||||
- name: Run fast PyTorch tests on M1 (MPS)
|
||||
shell: arch -arch arm64 bash {0}
|
||||
run: |
|
||||
${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps tests/
|
||||
${CONDA_RUN} python -m pytest -n 0 -s -v --make-reports=tests_torch_mps tests/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
|
||||
@@ -78,7 +78,7 @@ class DDIMPipeline(DiffusionPipeline):
|
||||
if generator is not None and generator.device.type != self.device.type and self.device.type != "mps":
|
||||
message = (
|
||||
f"The `generator` device is `{generator.device}` and does not match the pipeline "
|
||||
f"device `{self.device}`, so the `generator` will be set to `None`. "
|
||||
f"device `{self.device}`, so the `generator` will be ignored. "
|
||||
f'Please use `generator=torch.Generator(device="{self.device}")` instead.'
|
||||
)
|
||||
deprecate(
|
||||
@@ -89,11 +89,13 @@ class DDIMPipeline(DiffusionPipeline):
|
||||
generator = None
|
||||
|
||||
# Sample gaussian noise to begin loop
|
||||
image = torch.randn(
|
||||
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
|
||||
generator=generator,
|
||||
device=self.device,
|
||||
)
|
||||
image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
|
||||
if self.device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
image = torch.randn(image_shape, generator=generator)
|
||||
image = image.to(self.device)
|
||||
else:
|
||||
image = torch.randn(image_shape, generator=generator, device=self.device)
|
||||
|
||||
# set step values
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
@@ -83,7 +83,7 @@ class DDPMPipeline(DiffusionPipeline):
|
||||
if generator is not None and generator.device.type != self.device.type and self.device.type != "mps":
|
||||
message = (
|
||||
f"The `generator` device is `{generator.device}` and does not match the pipeline "
|
||||
f"device `{self.device}`, so the `generator` will be set to `None`. "
|
||||
f"device `{self.device}`, so the `generator` will be ignored. "
|
||||
f'Please use `torch.Generator(device="{self.device}")` instead.'
|
||||
)
|
||||
deprecate(
|
||||
@@ -94,11 +94,13 @@ class DDPMPipeline(DiffusionPipeline):
|
||||
generator = None
|
||||
|
||||
# Sample gaussian noise to begin loop
|
||||
image = torch.randn(
|
||||
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
|
||||
generator=generator,
|
||||
device=self.device,
|
||||
)
|
||||
image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
|
||||
if self.device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
image = torch.randn(image_shape, generator=generator)
|
||||
image = image.to(self.device)
|
||||
else:
|
||||
image = torch.randn(image_shape, generator=generator, device=self.device)
|
||||
|
||||
# set step values
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
@@ -81,10 +81,14 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
if torch_device == "mps":
|
||||
_ = ddpm(num_inference_steps=1)
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
if torch_device == "mps":
|
||||
# device type MPS is not supported for torch.Generator() api.
|
||||
generator = torch.manual_seed(0)
|
||||
else:
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
generator = generator.manual_seed(0)
|
||||
image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", predict_epsilon=False)[0]
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
@@ -1281,7 +1281,11 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest):
|
||||
|
||||
scheduler.set_timesteps(self.num_inference_steps)
|
||||
|
||||
generator = torch.Generator(torch_device).manual_seed(0)
|
||||
if torch_device == "mps":
|
||||
# device type MPS is not supported for torch.Generator() api.
|
||||
generator = torch.manual_seed(0)
|
||||
else:
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
|
||||
model = self.dummy_model()
|
||||
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
|
||||
@@ -1308,7 +1312,11 @@ class EulerDiscreteSchedulerTest(SchedulerCommonTest):
|
||||
|
||||
scheduler.set_timesteps(self.num_inference_steps, device=torch_device)
|
||||
|
||||
generator = torch.Generator(torch_device).manual_seed(0)
|
||||
if torch_device == "mps":
|
||||
# device type MPS is not supported for torch.Generator() api.
|
||||
generator = torch.manual_seed(0)
|
||||
else:
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
|
||||
model = self.dummy_model()
|
||||
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
|
||||
@@ -1364,7 +1372,11 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest):
|
||||
|
||||
scheduler.set_timesteps(self.num_inference_steps)
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
if torch_device == "mps":
|
||||
# device type MPS is not supported for torch.Generator() api.
|
||||
generator = torch.manual_seed(0)
|
||||
else:
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
|
||||
model = self.dummy_model()
|
||||
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
|
||||
@@ -1381,7 +1393,7 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest):
|
||||
result_sum = torch.sum(torch.abs(sample))
|
||||
result_mean = torch.mean(torch.abs(sample))
|
||||
|
||||
if str(torch_device).startswith("cpu"):
|
||||
if torch_device in ["cpu", "mps"]:
|
||||
assert abs(result_sum.item() - 152.3192) < 1e-2
|
||||
assert abs(result_mean.item() - 0.1983) < 1e-3
|
||||
else:
|
||||
@@ -1396,7 +1408,11 @@ class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest):
|
||||
|
||||
scheduler.set_timesteps(self.num_inference_steps, device=torch_device)
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
if torch_device == "mps":
|
||||
# device type MPS is not supported for torch.Generator() api.
|
||||
generator = torch.manual_seed(0)
|
||||
else:
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
|
||||
model = self.dummy_model()
|
||||
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
|
||||
|
||||
Reference in New Issue
Block a user