mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Fix _upsample_2d (#535)
* Fix _upsample_2d Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -149,7 +149,6 @@ class FirUpsample2D(nn.Module):
|
||||
|
||||
stride = (factor, factor)
|
||||
# Determine data dimensions.
|
||||
stride = [1, 1, factor, factor]
|
||||
output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
|
||||
output_padding = (
|
||||
output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
|
||||
@@ -161,7 +160,7 @@ class FirUpsample2D(nn.Module):
|
||||
|
||||
# Transpose weights.
|
||||
weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
|
||||
weight = weight[..., ::-1, ::-1].permute(0, 2, 1, 3, 4)
|
||||
weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
|
||||
weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
|
||||
|
||||
x = F.conv_transpose2d(x, weight, stride=stride, output_padding=output_padding, padding=0)
|
||||
|
||||
Reference in New Issue
Block a user