
Use ONNX / Core ML compatible method to broadcast (#310)

* Use ONNX / Core ML compatible method to broadcast.

Unfortunately, `tile` could not be used either; it is still not compatible
with ONNX.

See #284.

* Add a comment about why `broadcast_to` is not used.

Also, apply style to changed files.

* Make sure the broadcast result stays on the same device.
Pedro Cuenca
2022-09-02 18:22:57 +02:00
committed by GitHub
parent 7b628a225a
commit e49dd03d2d
2 changed files with 4 additions and 8 deletions
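
The change replaces `timesteps.broadcast_to(sample.shape[0])` with multiplication by a ones tensor. A minimal standalone sketch (not part of the commit; values are illustrative) of why the two forms agree in values, dtype, and device:

import torch

timesteps = torch.tensor([10], dtype=torch.long)
batch = 4

a = timesteps.broadcast_to(batch)  # form removed by this commit; trips up ONNX / Core ML export
b = timesteps * torch.ones(batch, dtype=timesteps.dtype, device=timesteps.device)

assert torch.equal(a, b)     # same values and dtype
assert a.device == b.device  # the explicit device keeps the result where timesteps lives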

src/diffusers/models/unet_2d.py

@@ -120,7 +120,6 @@ class UNet2DModel(ModelMixin, ConfigMixin):
    def forward(
        self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int]
    ) -> Dict[str, torch.FloatTensor]:
-
        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0
@@ -132,8 +131,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

-       # broadcast to batch dimension
-       timesteps = timesteps.broadcast_to(sample.shape[0])
+       # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+       timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)

        t_emb = self.time_proj(timesteps)
        emb = self.time_embedding(t_emb)
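
A hedged sketch (my assumption; no such test is included in the commit) of how the export-compatibility claim can be checked in isolation. Per #284, a forward pass using `broadcast_to` or `tile` instead made `torch.onnx.export` fail at the time:

import io
import torch

class TimestepBroadcast(torch.nn.Module):
    def forward(self, timesteps, sample):
        # same pattern as the line added above
        return timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)

# exports cleanly; a real check would write to a .onnx path and load it in onnxruntime
torch.onnx.export(TimestepBroadcast(), (torch.tensor([10.0]), torch.randn(2, 3, 32, 32)), io.BytesIO())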

src/diffusers/models/unet_2d_condition.py

@@ -121,7 +121,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
    ) -> Dict[str, torch.FloatTensor]:
-
        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0
@@ -133,8 +132,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

-       # broadcast to batch dimension
-       timesteps = timesteps.broadcast_to(sample.shape[0])
+       # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+       timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)

        t_emb = self.time_proj(timesteps)
        emb = self.time_embedding(t_emb)
@@ -145,7 +144,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
-
            if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
                sample, res_samples = downsample_block(
                    hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
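
For context, a hedged reconstruction of the dispatch this hunk touches (the else branch falls outside the diff context; this is my sketch, not the verbatim file): down blocks that hold cross-attention layers receive the text encoder states, plain ResNet blocks do not.

def run_down_blocks(down_blocks, sample, emb, encoder_hidden_states):
    down_block_res_samples = (sample,)
    for downsample_block in down_blocks:
        if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
            # cross-attention block: conditioned on the text encoder states
            sample, res_samples = downsample_block(
                hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
            )
        else:
            # plain ResNet block: no text conditioning
            sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
        down_block_res_samples += res_samples
    return sample, down_block_res_samples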
@@ -160,7 +158,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
        # 5. up
        for upsample_block in self.up_blocks:
-
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
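
The up path pops the residuals pushed during the down path in reverse order, len(upsample_block.resnets) at a time. A standalone sketch (hypothetical placeholder values) of that tuple bookkeeping:

down_block_res_samples = ("d0", "d1", "d2", "d3", "d4", "d5")
for n in (2, 2, 2):  # stand-in for len(upsample_block.resnets) per up block
    res_samples = down_block_res_samples[-n:]
    down_block_res_samples = down_block_res_samples[:-n]
    print(res_samples)
# ('d4', 'd5') then ('d2', 'd3') then ('d0', 'd1')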