From fdef40ba03bfbb02ffb53879f4daaccaddbdad07 Mon Sep 17 00:00:00 2001 From: thomasw21 <24695242+thomasw21@users.noreply.github.com> Date: Tue, 22 Nov 2022 00:57:19 +0100 Subject: [PATCH] Woops --- src/diffusers/schedulers/scheduling_vq_diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_vq_diffusion.py b/src/diffusers/schedulers/scheduling_vq_diffusion.py index 33ba8d6312..a1b203f4af 100644 --- a/src/diffusers/schedulers/scheduling_vq_diffusion.py +++ b/src/diffusers/schedulers/scheduling_vq_diffusion.py @@ -53,7 +53,7 @@ def index_to_log_onehot(x: torch.LongTensor, num_classes: int) -> torch.FloatTen Log onehot vectors """ batch_size, vector_length = x.shape - log_x = torch.FloatTensor((batch_size, num_classes, vector_length), fill_value=1e-30, device=x.device) + log_x = torch.full((batch_size, num_classes, vector_length), fill_value=1e-30, dtype=torch.float, device=x.device) log_x.scatter_(index=x[:, None, :], src=0, dim=1) return log_x