From 2312b27f796874658bc7391dd5d5c58b71dde153 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 3 Dec 2024 00:33:56 +0100 Subject: [PATCH] Interpolate fix on cuda for large output tensors (#10067) * Workaround for upscale with large output tensors. Fixes #10040. * Fix scale when output_size is given * Style --------- Co-authored-by: Sayak Paul --- src/diffusers/models/upsampling.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/diffusers/models/upsampling.py b/src/diffusers/models/upsampling.py index cf07e45b0c..af04ae4b93 100644 --- a/src/diffusers/models/upsampling.py +++ b/src/diffusers/models/upsampling.py @@ -165,6 +165,14 @@ class Upsample2D(nn.Module): # if `output_size` is passed we force the interpolation output # size and do not make use of `scale_factor=2` if self.interpolate: + # upsample_nearest_nhwc also fails when the number of output elements is large + # https://github.com/pytorch/pytorch/issues/141831 + scale_factor = ( + 2 if output_size is None else max([f / s for f, s in zip(output_size, hidden_states.shape[-2:])]) + ) + if hidden_states.numel() * scale_factor > pow(2, 31): + hidden_states = hidden_states.contiguous() + if output_size is None: hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") else: