1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

explicit broadcasts for assignments (#3535)

This commit is contained in:
Will Berman
2023-05-24 03:17:41 -07:00
committed by GitHub
parent c13dbd5c3a
commit db56f8a4f5

View File

@@ -433,7 +433,8 @@ class KDownsample2D(nn.Module):
x = F.pad(x, (self.pad,) * 4, self.pad_mode)
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
indices = torch.arange(x.shape[1], device=x.device)
weight[indices, indices] = self.kernel.to(weight)
kernel = self.kernel.to(weight)[None, :].expand(x.shape[1], -1, -1)
weight[indices, indices] = kernel
return F.conv2d(x, weight, stride=2)
@@ -449,7 +450,8 @@ class KUpsample2D(nn.Module):
x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode)
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
indices = torch.arange(x.shape[1], device=x.device)
weight[indices, indices] = self.kernel.to(weight)
kernel = self.kernel.to(weight)[None, :].expand(x.shape[1], -1, -1)
weight[indices, indices] = kernel
return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1)