diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index e9aeb5eb06..37641429f6 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -152,7 +152,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device) + timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) emb = self.time_embedding(t_emb)