Refactor of the classical VQ module. Introduce dedicated encode and decode methods.
MisterBourbaki committed Jun 10, 2024
1 parent 5af5b6c commit f288f0d
Showing 1 changed file with 53 additions and 29 deletions.
82 changes: 53 additions & 29 deletions encyclopedia_vae/modules/vectorquantizer.py
@@ -15,46 +15,70 @@ def __init__(self, num_embeddings: int, embedding_dim: int, beta: float = 0.25):
         self.D = embedding_dim
         self.beta = beta
 
-        self.embedding = nn.Embedding(self.K, self.D)
-        self.embedding.weight.data.uniform_(-1 / self.K, 1 / self.K)
+        self.codebook = nn.Embedding(self.K, self.D)
+        self.codebook.weight.data.uniform_(-1 / self.K, 1 / self.K)
 
-    def forward(self, latents: torch.tensor) -> torch.tensor:
-        latents = latents.permute(
-            0, 2, 3, 1
-        ).contiguous()  # [B x D x H x W] -> [B x H x W x D]
-        latents_shape = latents.shape
-        flat_latents = latents.view(-1, self.D)  # [BHW x D]
+    @property
+    def resolution(self) -> torch.Tensor:
+        """Compute the resolution of the Vector Quantizer.
+
+        The resolution is the log2 of the number of embeddings divided by
+        the dimension of the embedding. This is the same as the bitrate per dimension.
+
+        Returns
+        -------
+        torch.Tensor
+            the log2 of the number of embeddings divided by the dimension.
+        """
+        return torch.log2(torch.tensor(self.K)) / self.D
+
+    def encode(self, latents: torch.Tensor) -> torch.Tensor:
+        """Encode the latents by nearest neighbours.
+
+        Parameters
+        ----------
+        latents : torch.Tensor
+            should be channel last!
+
+        Returns
+        -------
+        torch.Tensor
+            tensor holding the coding indices for latents.
+        """
+        encodings_shape = latents.shape[:-1]
+        flat_latents = latents.view(-1, self.D)
 
         # Compute L2 distance between latents and embedding weights
         dist = (
             torch.sum(flat_latents**2, dim=1, keepdim=True)
-            + torch.sum(self.embedding.weight**2, dim=1)
-            - 2 * torch.matmul(flat_latents, self.embedding.weight.t())
-        )  # [BHW x K]
+            + torch.sum(self.codebook.weight**2, dim=1)
+            - 2 * torch.matmul(flat_latents, self.codebook.weight.t())
+        )
 
-        # Get the encoding that has the min distance
-        encoding_inds = torch.argmin(dist, dim=1).unsqueeze(1)  # [BHW, 1]
+        encoding_inds = torch.argmin(dist, dim=1)
+        return encoding_inds.view(encodings_shape)
 
-        # Convert to one-hot encodings
-        device = latents.device
-        encoding_one_hot = torch.zeros(encoding_inds.size(0), self.K, device=device)
-        encoding_one_hot.scatter_(1, encoding_inds, 1)  # [BHW x K]
-
-        # Quantize the latents
-        quantized_latents = torch.matmul(
-            encoding_one_hot, self.embedding.weight
-        )  # [BHW, D]
-        quantized_latents = quantized_latents.view(latents_shape)  # [B x H x W x D]
+    def decode(self, indices: torch.Tensor) -> torch.Tensor:
+        """Decode the given indices into vectors, using the codebook.
+
+        Parameters
+        ----------
+        indices : torch.Tensor
+            tensor holding indices as integers
+
+        Returns
+        -------
+        torch.Tensor
+            a channel last tensor
+        """
+        return self.codebook(indices)
+
+    def forward(self, latents: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        quantized_latents = self.decode(self.encode(latents))
 
         # Compute the VQ Losses
         commitment_loss = functional.mse_loss(quantized_latents.detach(), latents)
         embedding_loss = functional.mse_loss(quantized_latents, latents.detach())
 
         vq_loss = commitment_loss * self.beta + embedding_loss
 
         # Add the residue back to the latents
         quantized_latents = latents + (quantized_latents - latents).detach()
 
-        return quantized_latents.permute(
-            0, 3, 1, 2
-        ).contiguous(), vq_loss  # [B x D x H x W]
+        return quantized_latents, vq_loss
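
Below the diff, a minimal usage sketch of the refactored API (not part of the commit). It assumes the class defined in encyclopedia_vae/modules/vectorquantizer.py is importable as VectorQuantizer and subclasses nn.Module; the constructor arguments and the channel-last convention are taken from the hunk above, while the concrete sizes (512 codes, 64-dimensional embeddings, an 8x16x16 latent grid) are purely illustrative.

    # Usage sketch (not part of the commit). Class name and nn.Module
    # inheritance are assumptions; shapes are illustrative only.
    import torch

    from encyclopedia_vae.modules.vectorquantizer import VectorQuantizer  # assumed export

    vq = VectorQuantizer(num_embeddings=512, embedding_dim=64, beta=0.25)

    # encode/decode expect channel-last latents, e.g. [B, H, W, D]
    latents = torch.randn(8, 16, 16, 64)
    indices = vq.encode(latents)   # [8, 16, 16] integer codebook indices
    vectors = vq.decode(indices)   # [8, 16, 16, 64] channel-last codebook vectors

    # forward chains decode(encode(.)) and returns the straight-through
    # quantized latents together with the VQ loss
    quantized, vq_loss = vq(latents)

    # resolution = log2(K) / D, the bitrate per latent dimension:
    # log2(512) / 64 = 9 / 64 ≈ 0.14 bits per dimension
    print(vq.resolution)

Note that the permutes were removed from forward, so callers that previously passed [B, D, H, W] feature maps are now responsible for moving the channel dimension last before calling the module.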

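For reference only (also not part of the commit), the unchanged loss lines in the hunk map onto the usual VQ-VAE objective; writing sg[.] for .detach() (stop-gradient), z_e for the input latents and z_q for their quantized version:

    \mathcal{L}_{\mathrm{VQ}} = \beta \,\lVert z_e - \mathrm{sg}[z_q] \rVert_2^2 + \lVert \mathrm{sg}[z_e] - z_q \rVert_2^2

with functional.mse_loss taking the mean over elements rather than the sum, and with the line latents + (quantized_latents - latents).detach() acting as the straight-through estimator that lets reconstruction gradients bypass the non-differentiable argmin.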