mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
@@ -30,6 +30,7 @@ from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
torch_device,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin, to_np
|
||||
@@ -151,7 +152,7 @@ class SanaControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
control_image = torch.randn(1, 3, 32, 32, generator=generator)
|
||||
control_image = randn_tensor((1, 3, 32, 32), generator=generator, device=device)
|
||||
inputs = {
|
||||
"prompt": "",
|
||||
"negative_prompt": "",
|
||||
|
||||
@@ -24,6 +24,7 @@ from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
torch_device,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ..pipeline_params import (
|
||||
IMAGE_TO_IMAGE_IMAGE_PARAMS,
|
||||
@@ -137,7 +138,7 @@ class SanaSprintImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
image = torch.randn(1, 3, 32, 32, generator=generator)
|
||||
image = randn_tensor((1, 3, 32, 32), generator=generator, device=device)
|
||||
inputs = {
|
||||
"prompt": "",
|
||||
"image": image,
|
||||
|
||||
Reference in New Issue
Block a user