mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
Woops
This commit is contained in:
@@ -53,7 +53,7 @@ def index_to_log_onehot(x: torch.LongTensor, num_classes: int) -> torch.FloatTen
|
||||
Log onehot vectors
|
||||
"""
|
||||
batch_size, vector_length = x.shape
|
||||
log_x = torch.FloatTensor((batch_size, num_classes, vector_length), fill_value=1e-30, device=x.device)
|
||||
log_x = torch.full((batch_size, num_classes, vector_length), fill_value=1e-30, dtype=torch.float, device=x.device)
|
||||
log_x.scatter_(index=x[:, None, :], src=0, dim=1)
|
||||
return log_x
|
||||
|
||||
|
||||
Reference in New Issue
Block a user