diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index e29f4e8afa..bcfc8789ab 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -290,15 +290,10 @@ class VectorQuantizer(nn.Module): # reshape z -> (batch, height, width, channel) and flatten z = z.permute(0, 2, 3, 1).contiguous() z_flattened = z.view(-1, self.vq_embed_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + min_encoding_indices = torch.argmin(torch.cdist(z_flattened, self.embedding.weight), dim=1) - d = ( - torch.sum(z_flattened**2, dim=1, keepdim=True) - + torch.sum(self.embedding.weight**2, dim=1) - - 2 * torch.einsum("bd,dn->bn", z_flattened, self.embedding.weight.t()) - ) - - min_encoding_indices = torch.argmin(d, dim=1) z_q = self.embedding(min_encoding_indices).view(z.shape) perplexity = None min_encodings = None