1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[{Up,Down}sample1d] explicit view kernel size as number elements in flattened indices (#3479)

explicit view kernel size as number elements in flattened indices
This commit is contained in:
Will Berman
2023-05-19 10:45:56 -07:00
committed by GitHub
parent e589bdb956
commit 85eff637aa

View File

@@ -300,7 +300,8 @@ class Downsample1d(nn.Module):
hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode)
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
weight[indices, indices] = self.kernel.to(weight)
kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
weight[indices, indices] = kernel
return F.conv1d(hidden_states, weight, stride=2)
@@ -316,7 +317,8 @@ class Upsample1d(nn.Module):
hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode)
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
weight[indices, indices] = self.kernel.to(weight)
kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
weight[indices, indices] = kernel
return F.conv_transpose1d(hidden_states, weight, stride=2, padding=self.pad * 2 + 1)