From 0c11c8c1ac4c3520fefaa3e36638634a5d69b790 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 19 Jun 2025 13:36:02 +0200 Subject: [PATCH] [CI] Fix SANA tests (#11756) update --- tests/pipelines/sana/test_sana_controlnet.py | 3 ++- tests/pipelines/sana/test_sana_sprint_img2img.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/pipelines/sana/test_sana_controlnet.py b/tests/pipelines/sana/test_sana_controlnet.py index 803f608ba6..9b5c9e439e 100644 --- a/tests/pipelines/sana/test_sana_controlnet.py +++ b/tests/pipelines/sana/test_sana_controlnet.py @@ -30,6 +30,7 @@ from diffusers.utils.testing_utils import ( enable_full_determinism, torch_device, ) +from diffusers.utils.torch_utils import randn_tensor from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS from ..test_pipelines_common import PipelineTesterMixin, to_np @@ -151,7 +152,7 @@ class SanaControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase): else: generator = torch.Generator(device=device).manual_seed(seed) - control_image = torch.randn(1, 3, 32, 32, generator=generator) + control_image = randn_tensor((1, 3, 32, 32), generator=generator, device=device) inputs = { "prompt": "", "negative_prompt": "", diff --git a/tests/pipelines/sana/test_sana_sprint_img2img.py b/tests/pipelines/sana/test_sana_sprint_img2img.py index 1179346d4c..c0e4bf8e35 100644 --- a/tests/pipelines/sana/test_sana_sprint_img2img.py +++ b/tests/pipelines/sana/test_sana_sprint_img2img.py @@ -24,6 +24,7 @@ from diffusers.utils.testing_utils import ( enable_full_determinism, torch_device, ) +from diffusers.utils.torch_utils import randn_tensor from ..pipeline_params import ( IMAGE_TO_IMAGE_IMAGE_PARAMS, @@ -137,7 +138,7 @@ class SanaSprintImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase) generator = torch.manual_seed(seed) else: generator = torch.Generator(device=device).manual_seed(seed) - image = torch.randn(1, 3, 32, 32, generator=generator) + image = randn_tensor((1, 3, 32, 32), generator=generator, device=device) inputs = { "prompt": "", "image": image,