mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[chore] allow string device to be passed to randn_tensor. (#11559)
allow string device to be passed to randn_tensor.
This commit is contained in:
@@ -38,7 +38,7 @@ except (ImportError, ModuleNotFoundError):
|
||||
def randn_tensor(
|
||||
shape: Union[Tuple, List],
|
||||
generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
|
||||
device: Optional["torch.device"] = None,
|
||||
device: Optional[Union[str, "torch.device"]] = None,
|
||||
dtype: Optional["torch.dtype"] = None,
|
||||
layout: Optional["torch.layout"] = None,
|
||||
):
|
||||
@@ -47,6 +47,8 @@ def randn_tensor(
|
||||
is always created on the CPU.
|
||||
"""
|
||||
# device on which tensor is created defaults to device
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
rand_device = device
|
||||
batch_size = shape[0]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user