From 85eff637aad1106f593d7535ec41cdb736b0b2ea Mon Sep 17 00:00:00 2001 From: Will Berman Date: Fri, 19 May 2023 10:45:56 -0700 Subject: [PATCH] [{Up,Down}sample1d] explicit view kernel size as number elements in flattened indices (#3479) explicit view kernel size as number elements in flattened indices --- src/diffusers/models/unet_1d_blocks.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/unet_1d_blocks.py b/src/diffusers/models/unet_1d_blocks.py index a0f0e58f91..934a4a4a7d 100644 --- a/src/diffusers/models/unet_1d_blocks.py +++ b/src/diffusers/models/unet_1d_blocks.py @@ -300,7 +300,8 @@ class Downsample1d(nn.Module): hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode) weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]]) indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) - weight[indices, indices] = self.kernel.to(weight) + kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1) + weight[indices, indices] = kernel return F.conv1d(hidden_states, weight, stride=2) @@ -316,7 +317,8 @@ class Upsample1d(nn.Module): hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode) weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]]) indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) - weight[indices, indices] = self.kernel.to(weight) + kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1) + weight[indices, indices] = kernel return F.conv_transpose1d(hidden_states, weight, stride=2, padding=self.pad * 2 + 1)