diff --git a/src/diffusers/pipelines/ltx2/latent_upsampler.py b/src/diffusers/pipelines/ltx2/latent_upsampler.py index 7cb84b951a..4347bbf1ed 100644 --- a/src/diffusers/pipelines/ltx2/latent_upsampler.py +++ b/src/diffusers/pipelines/ltx2/latent_upsampler.py @@ -126,7 +126,10 @@ class BlurDownsample(torch.nn.Module): # dims == 3: apply per-frame on H,W b, c, f, _, _ = x.shape x = x.transpose(1, 2).flatten(0, 1) # [B, C, F, H, W] --> [B * F, C, H, W] - x = self._apply_2d(x) + + weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise + x = F.conv2d(x, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c) + h2, w2 = x.shape[-2:] x = x.unflatten(0, (b, f)).reshape(b, -1, f, h2, w2) # [B * F, C, H, W] --> [B, C, F, H, W] return x