Compute embedding distances with torch.cdist (#1459)

small but mighty
Author: Benjamin Lefaudeux
Date: 2022-12-05 12:37:05 +01:00
Committed by: GitHub
Parent: 513fc68104
Commit: 720dbfc985

@@ -290,15 +290,10 @@ class VectorQuantizer(nn.Module):
         # reshape z -> (batch, height, width, channel) and flatten
         z = z.permute(0, 2, 3, 1).contiguous()
         z_flattened = z.view(-1, self.vq_embed_dim)
         # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+        min_encoding_indices = torch.argmin(torch.cdist(z_flattened, self.embedding.weight), dim=1)
-        d = (
-            torch.sum(z_flattened**2, dim=1, keepdim=True)
-            + torch.sum(self.embedding.weight**2, dim=1)
-            - 2 * torch.einsum("bd,dn->bn", z_flattened, self.embedding.weight.t())
-        )
-        min_encoding_indices = torch.argmin(d, dim=1)
         z_q = self.embedding(min_encoding_indices).view(z.shape)
         perplexity = None
         min_encodings = None
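
Note (not part of the commit): the sketch below is a minimal standalone check of why this swap is safe. The removed code expanded the squared distance as (z - e)^2 = z^2 + e^2 - 2 e * z, while torch.cdist returns the Euclidean distance; since the square root is monotonic, argmin over either quantity selects the same codebook index. The tensor sizes here are arbitrary illustrative values, not taken from the model.

import torch

# Arbitrary example sizes: 8 flattened latents, 16 codebook entries, embedding dim 4.
torch.manual_seed(0)
z_flattened = torch.randn(8, 4)   # stand-in for z.view(-1, vq_embed_dim)
codebook = torch.randn(16, 4)     # stand-in for self.embedding.weight

# Removed formulation: expand (z - e)^2 = z^2 + e^2 - 2 e * z.
d = (
    torch.sum(z_flattened**2, dim=1, keepdim=True)
    + torch.sum(codebook**2, dim=1)
    - 2 * torch.einsum("bd,dn->bn", z_flattened, codebook.t())
)
old_indices = torch.argmin(d, dim=1)

# New formulation: pairwise Euclidean distances via torch.cdist.
new_indices = torch.argmin(torch.cdist(z_flattened, codebook), dim=1)

# argmin over squared distances and over distances agrees (sqrt is monotonic).
assert torch.equal(old_indices, new_indices)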