From 56c003705f510700e4687edfdebfb953325c2fab Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 6 Sep 2022 17:36:32 +0200 Subject: [PATCH] Use `expand` instead of ones to broadcast tensor (#373) Use `expand` instead of ones to broadcast tensor. As suggested by @bes-dev. According the documentation this shouldn't take any memory - it just plays with the strides. --- src/diffusers/models/unet_2d_condition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index e9aeb5eb06..37641429f6 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -152,7 +152,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device) + timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) emb = self.time_embedding(t_emb)