1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

warmup issue again?

This commit is contained in:
anton-l
2022-11-09 23:47:49 +01:00
parent 85ec36bac7
commit 86d4c5a254
4 changed files with 6 additions and 8 deletions

View File

@@ -136,7 +136,7 @@ jobs:
- name: Run fast PyTorch tests on M1 (MPS)
shell: arch -arch arm64 bash {0}
run: |
${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps tests/
${CONDA_RUN} python -m pytest -n 0 -s -v --make-reports=tests_torch_mps tests/
- name: Failure short reports
if: ${{ failure() }}

View File

@@ -78,7 +78,7 @@ class DDIMPipeline(DiffusionPipeline):
if generator is not None and generator.device.type != self.device.type and self.device.type != "mps":
message = (
f"The `generator` device is `{generator.device}` and does not match the pipeline "
f"device `{self.device}`, so the `generator` will be set to `None`. "
f"device `{self.device}`, so the `generator` will be ignored. "
f'Please use `generator=torch.Generator(device="{self.device}")` instead.'
)
deprecate(

View File

@@ -83,7 +83,7 @@ class DDPMPipeline(DiffusionPipeline):
if generator is not None and generator.device.type != self.device.type and self.device.type != "mps":
message = (
f"The `generator` device is `{generator.device}` and does not match the pipeline "
f"device `{self.device}`, so the `generator` will be set to `None`. "
f"device `{self.device}`, so the `generator` will be ignored. "
f'Please use `torch.Generator(device="{self.device}")` instead.'
)
deprecate(

View File

@@ -80,7 +80,6 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
# Warmup pass when using mps (see #372)
if torch_device == "mps":
_ = ddpm(num_inference_steps=1)
if torch_device == "mps":
# device type MPS is not supported for torch.Generator() api.
generator = torch.manual_seed(0)
@@ -88,11 +87,10 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
generator = torch.Generator(device=torch_device).manual_seed(0)
image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
# Warmup pass when using mps (see #372)
if torch_device == "mps":
# device type MPS is not supported for torch.Generator() api.
generator = torch.manual_seed(0)
else:
generator = torch.Generator(device=torch_device).manual_seed(0)
_ = ddpm(num_inference_steps=1)
generator = generator.manual_seed(0)
image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", predict_epsilon=False)[0]
image_slice = image[0, -3:, -3:, -1]