mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
explicit broadcasts for assignments (#3535)
This commit is contained in:
@@ -433,7 +433,8 @@ class KDownsample2D(nn.Module):
|
||||
x = F.pad(x, (self.pad,) * 4, self.pad_mode)
|
||||
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
|
||||
indices = torch.arange(x.shape[1], device=x.device)
|
||||
weight[indices, indices] = self.kernel.to(weight)
|
||||
kernel = self.kernel.to(weight)[None, :].expand(x.shape[1], -1, -1)
|
||||
weight[indices, indices] = kernel
|
||||
return F.conv2d(x, weight, stride=2)
|
||||
|
||||
|
||||
@@ -449,7 +450,8 @@ class KUpsample2D(nn.Module):
|
||||
x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode)
|
||||
weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
|
||||
indices = torch.arange(x.shape[1], device=x.device)
|
||||
weight[indices, indices] = self.kernel.to(weight)
|
||||
kernel = self.kernel.to(weight)[None, :].expand(x.shape[1], -1, -1)
|
||||
weight[indices, indices] = kernel
|
||||
return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user