1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
thomasw21
2022-11-22 00:57:19 +01:00
parent fe691feb5a
commit fdef40ba03

View File

@@ -53,7 +53,7 @@ def index_to_log_onehot(x: torch.LongTensor, num_classes: int) -> torch.FloatTen
Log onehot vectors
"""
batch_size, vector_length = x.shape
log_x = torch.FloatTensor((batch_size, num_classes, vector_length), fill_value=1e-30, device=x.device)
log_x = torch.full((batch_size, num_classes, vector_length), fill_value=1e-30, dtype=torch.float, device=x.device)
log_x.scatter_(index=x[:, None, :], src=0, dim=1)
return log_x