From 2457599114b8dea7455745fe8b032bbf784b974b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 2 Oct 2023 17:53:17 +0000 Subject: [PATCH] make fix copies --- src/diffusers/schedulers/scheduling_unipc_multistep.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index d61341cee7..18d95fe514 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -282,13 +282,13 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): https://arxiv.org/abs/2205.11487 """ dtype = sample.dtype - batch_size, channels, *remaining_dims = sample.shape + batch_size, channels, height, width = sample.shape if dtype not in (torch.float32, torch.float64): sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half # Flatten sample for doing quantile calculation along each image - sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + sample = sample.reshape(batch_size, channels * height * width) abs_sample = sample.abs() # "a certain percentile absolute pixel value" @@ -300,7 +300,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" - sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.reshape(batch_size, channels, height, width) sample = sample.to(dtype) return sample