mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Fix typo in BlurDownsample
This commit is contained in:
@@ -126,7 +126,10 @@ class BlurDownsample(torch.nn.Module):
|
||||
# dims == 3: apply per-frame on H,W
|
||||
b, c, f, _, _ = x.shape
|
||||
x = x.transpose(1, 2).flatten(0, 1) # [B, C, F, H, W] --> [B * F, C, H, W]
|
||||
x = self._apply_2d(x)
|
||||
|
||||
weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise
|
||||
x = F.conv2d(x, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c)
|
||||
|
||||
h2, w2 = x.shape[-2:]
|
||||
x = x.unflatten(0, (b, f)).reshape(b, -1, f, h2, w2) # [B * F, C, H, W] --> [B, C, F, H, W]
|
||||
return x
|
||||
|
||||
Reference in New Issue
Block a user