From 42ba85998ffda74b4a737cfe1f506a51328119c7 Mon Sep 17 00:00:00 2001 From: thomasw21 <24695242+thomasw21@users.noreply.github.com> Date: Tue, 22 Nov 2022 01:11:18 +0100 Subject: [PATCH] scatter_ argument is not called src, but rather value --- src/diffusers/schedulers/scheduling_vq_diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_vq_diffusion.py b/src/diffusers/schedulers/scheduling_vq_diffusion.py index a1b203f4af..83041b1d3e 100644 --- a/src/diffusers/schedulers/scheduling_vq_diffusion.py +++ b/src/diffusers/schedulers/scheduling_vq_diffusion.py @@ -54,7 +54,7 @@ def index_to_log_onehot(x: torch.LongTensor, num_classes: int) -> torch.FloatTen """ batch_size, vector_length = x.shape log_x = torch.full((batch_size, num_classes, vector_length), fill_value=1e-30, dtype=torch.float, device=x.device) - log_x.scatter_(index=x[:, None, :], src=0, dim=1) + log_x.scatter_(index=x[:, None, :], value=0.0, dim=1) return log_x