Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-29 07:22:12 +03:00)
Use ONNX / Core ML compatible method to broadcast (#310)
* Use ONNX / Core ML compatible method to broadcast. Unfortunately `tile` could not be used either; it's still not compatible with ONNX. See #284.
* Add a comment about why `broadcast_to` is not used. Also, apply style to the changed files.
* Make sure the broadcast stays on the same device.
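For context, a standalone sketch (not code from this commit) of the replacement: a 0-d timestep is broadcast to the batch dimension by multiplying with a `torch.ones` tensor of the target shape, which gives the same result as `broadcast_to` while, per the commit message, staying compatible with ONNX / Core ML export; passing `dtype` and `device` explicitly keeps the result on the timestep's device.

```python
import torch

# Standalone sketch: the two broadcast approaches side by side.
batch_size = 4
timesteps = torch.tensor(50.0)  # 0-d timestep tensor

# Old approach (removed in this commit): per the commit message, not compatible
# with ONNX / Core ML export at the time.
broadcast_old = timesteps.broadcast_to((batch_size,))

# New approach: multiply by ones of the target shape; dtype/device are passed
# explicitly so the result stays on the same device as `timesteps`.
broadcast_new = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device)

assert torch.equal(broadcast_old, broadcast_new)  # both are shape (4,) filled with 50.0
```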
@@ -120,7 +120,6 @@ class UNet2DModel(ModelMixin, ConfigMixin):
    def forward(
        self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int]
    ) -> Dict[str, torch.FloatTensor]:

        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0
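The `center_input_sample` branch in the hunk above is a simple range shift; a minimal illustration (not repo code):

```python
import torch

# "# 0. center input if necessary": samples in [0, 1] are shifted to [-1, 1]
# before entering the UNet.
sample = torch.rand(1, 3, 8, 8)      # values in [0, 1]
centered = 2 * sample - 1.0          # values in [-1, 1]
assert centered.min() >= -1.0 and centered.max() <= 1.0
```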
@@ -132,8 +131,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

-        # broadcast to batch dimension
-        timesteps = timesteps.broadcast_to(sample.shape[0])
+        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+        timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)

        t_emb = self.time_proj(timesteps)
        emb = self.time_embedding(t_emb)
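To see why the rewrite matters for export, here is a hypothetical toy module (not from diffusers) that isolates the new pattern and runs through `torch.onnx.export`; per the commit message and #284, the earlier `broadcast_to` / `tile` formulations did not convert cleanly at the time.

```python
import torch


class ToyTimestepBroadcast(torch.nn.Module):
    """Hypothetical module isolating the multiply-by-ones broadcast from the hunk above."""

    def forward(self, sample: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
        return timestep * torch.ones(sample.shape[0], dtype=timestep.dtype, device=timestep.device)


sample = torch.randn(2, 3, 8, 8)
timestep = torch.tensor(10.0)
# The traced graph contains only mul/ones-style ops, which the commit relies on
# being supported by the ONNX and Core ML toolchains.
torch.onnx.export(ToyTimestepBroadcast(), (sample, timestep), "toy_broadcast.onnx", opset_version=11)
```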
@@ -121,7 +121,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
    ) -> Dict[str, torch.FloatTensor]:

        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0
@@ -133,8 +132,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

-        # broadcast to batch dimension
-        timesteps = timesteps.broadcast_to(sample.shape[0])
+        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+        timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)

        t_emb = self.time_proj(timesteps)
        emb = self.time_embedding(t_emb)
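Both models apply the same change inside their timestep-normalization block. A standalone sketch of that whole block (a hypothetical helper, not the diffusers API) for a Python number or a 0-d tensor:

```python
import torch


def normalize_timesteps(timestep, sample: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper mirroring the logic around the changed lines: lift plain
    # numbers / 0-d tensors to a 1-element tensor on the sample's device, then
    # broadcast to the batch dimension with the ONNX/Core ML friendly pattern.
    timesteps = timestep
    if not torch.is_tensor(timesteps):
        timesteps = torch.tensor([timesteps], device=sample.device)
    elif len(timesteps.shape) == 0:
        timesteps = timesteps[None].to(sample.device)
    return timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)


sample = torch.randn(3, 4, 8, 8)
print(normalize_timesteps(10, sample))                  # tensor([10, 10, 10])
print(normalize_timesteps(torch.tensor(10.0), sample))  # tensor([10., 10., 10.])
```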
@@ -145,7 +144,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:

            if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
                sample, res_samples = downsample_block(
                    hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
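For orientation, a skeleton of the down-pass residual collection shown above, using stand-in blocks rather than the diffusers classes:

```python
import torch

# Stand-in blocks instead of the real down blocks; each one here yields a single
# residual, whereas the real blocks return several.
blocks = [torch.nn.Identity() for _ in range(3)]
sample = torch.randn(1, 4, 8, 8)

down_block_res_samples = (sample,)           # residuals kept for the skip connections
for block in blocks:
    sample = block(sample)                   # stand-in for the downsample_block(...) call
    down_block_res_samples += (sample,)

print(len(down_block_res_samples))           # 1 initial sample + 1 per block = 4
```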
@@ -160,7 +158,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):

        # 5. up
        for upsample_block in self.up_blocks:

            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
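And the mirror image of that collection for the "# 5. up" slicing above: residuals are consumed from the end of the tuple so each up block is paired with its corresponding down block. A standalone sketch, where a fixed per-block residual count stands in for `len(upsample_block.resnets)`:

```python
import torch

n_resnets = 2                                # stand-in for len(upsample_block.resnets)
down_block_res_samples = tuple(torch.randn(1, 4, 8, 8) for _ in range(6))  # 3 down blocks x 2

for _ in range(3):                           # one iteration per up block
    res_samples = down_block_res_samples[-n_resnets:]
    down_block_res_samples = down_block_res_samples[:-n_resnets]
    # the real loop passes res_samples into the upsample_block call here

print(len(down_block_res_samples))           # 0: every saved residual has been consumed
```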