mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Use expand instead of ones to broadcast tensor (#373)
Use `expand` instead of ones to broadcast tensor. As suggested by @bes-dev. According the documentation this shouldn't take any memory - it just plays with the strides.
This commit is contained in:
@@ -152,7 +152,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
|
||||
timesteps = timesteps.expand(sample.shape[0])
|
||||
|
||||
t_emb = self.time_proj(timesteps)
|
||||
emb = self.time_embedding(t_emb)
|
||||
|
||||
Reference in New Issue
Block a user