1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Use expand instead of ones to broadcast tensor (#373)

Use `expand` instead of ones to broadcast tensor.

As suggested by @bes-dev. According the documentation this shouldn't
take any memory - it just plays with the strides.
This commit is contained in:
Pedro Cuenca
2022-09-06 17:36:32 +02:00
committed by GitHub
parent 7a1229fa29
commit 56c003705f

View File

@@ -152,7 +152,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
timesteps = timesteps.expand(sample.shape[0])
t_emb = self.time_proj(timesteps)
emb = self.time_embedding(t_emb)