From 54af3ca7fd910c8abeeb022ee7583ef725bce85c Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 21 May 2025 14:55:54 +0530 Subject: [PATCH] [chore] allow string device to be passed to randn_tensor. (#11559) allow string device to be passed to randn_tensor. --- src/diffusers/utils/torch_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index a970542293..053a3d99b9 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -38,7 +38,7 @@ except (ImportError, ModuleNotFoundError): def randn_tensor( shape: Union[Tuple, List], generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None, - device: Optional["torch.device"] = None, + device: Optional[Union[str, "torch.device"]] = None, dtype: Optional["torch.dtype"] = None, layout: Optional["torch.layout"] = None, ): @@ -47,6 +47,8 @@ def randn_tensor( is always created on the CPU. """ # device on which tensor is created defaults to device + if isinstance(device, str): + device = torch.device(device) rand_device = device batch_size = shape[0]