diff --git a/src/diffusers/schedulers/scheduling_vq_diffusion.py b/src/diffusers/schedulers/scheduling_vq_diffusion.py index a1b203f4af..83041b1d3e 100644 --- a/src/diffusers/schedulers/scheduling_vq_diffusion.py +++ b/src/diffusers/schedulers/scheduling_vq_diffusion.py @@ -54,7 +54,7 @@ def index_to_log_onehot(x: torch.LongTensor, num_classes: int) -> torch.FloatTen """ batch_size, vector_length = x.shape log_x = torch.full((batch_size, num_classes, vector_length), fill_value=1e-30, dtype=torch.float, device=x.device) - log_x.scatter_(index=x[:, None, :], src=0, dim=1) + log_x.scatter_(index=x[:, None, :], value=0.0, dim=1) return log_x