diff --git a/encyclopedia_vae/modules/vectorquantizer.py b/encyclopedia_vae/modules/vectorquantizer.py
index 456139e..5300edb 100644
--- a/encyclopedia_vae/modules/vectorquantizer.py
+++ b/encyclopedia_vae/modules/vectorquantizer.py
@@ -15,46 +15,70 @@ def __init__(self, num_embeddings: int, embedding_dim: int, beta: float = 0.25):
         self.D = embedding_dim
         self.beta = beta
 
-        self.embedding = nn.Embedding(self.K, self.D)
-        self.embedding.weight.data.uniform_(-1 / self.K, 1 / self.K)
+        self.codebook = nn.Embedding(self.K, self.D)
+        self.codebook.weight.data.uniform_(-1 / self.K, 1 / self.K)
 
-    def forward(self, latents: torch.tensor) -> torch.tensor:
-        latents = latents.permute(
-            0, 2, 3, 1
-        ).contiguous()  # [B x D x H x W] -> [B x H x W x D]
-        latents_shape = latents.shape
-        flat_latents = latents.view(-1, self.D)  # [BHW x D]
+    @property
+    def resolution(self) -> torch.Tensor:
+        """Compute the resolution of the vector quantizer.
+
+        The resolution is log2 of the number of embeddings divided by the
+        embedding dimension, i.e. the bitrate per dimension.
+
+        Returns
+        -------
+        torch.Tensor
+            log2 of the number of embeddings divided by the embedding dimension.
+        """
+        return torch.log2(torch.tensor(self.K, dtype=torch.float)) / self.D
+
+    def encode(self, latents: torch.Tensor) -> torch.Tensor:
+        """Encode the latents by nearest-neighbor search in the codebook.
+
+        Parameters
+        ----------
+        latents : torch.Tensor
+            should be channel-last!
+
+        Returns
+        -------
+        torch.Tensor
+            tensor holding the codebook indices for the latents.
+        """
+        encodings_shape = latents.shape[:-1]
+        flat_latents = latents.reshape(-1, self.D)
 
-        # Compute L2 distance between latents and embedding weights
         dist = (
             torch.sum(flat_latents**2, dim=1, keepdim=True)
-            + torch.sum(self.embedding.weight**2, dim=1)
-            - 2 * torch.matmul(flat_latents, self.embedding.weight.t())
-        )  # [BHW x K]
+            + torch.sum(self.codebook.weight**2, dim=1)
+            - 2 * torch.matmul(flat_latents, self.codebook.weight.t())
+        )
+
+        encoding_inds = torch.argmin(dist, dim=1)
+        return encoding_inds.view(encodings_shape)
 
-        # Get the encoding that has the min distance
-        encoding_inds = torch.argmin(dist, dim=1).unsqueeze(1)  # [BHW, 1]
+    def decode(self, indices: torch.Tensor) -> torch.Tensor:
+        """Decode the given indices into vectors, using the codebook.
 
-        # Convert to one-hot encodings
-        device = latents.device
-        encoding_one_hot = torch.zeros(encoding_inds.size(0), self.K, device=device)
-        encoding_one_hot.scatter_(1, encoding_inds, 1)  # [BHW x K]
+        Parameters
+        ----------
+        indices : torch.Tensor
+            tensor holding indices as integers
 
-        # Quantize the latents
-        quantized_latents = torch.matmul(
-            encoding_one_hot, self.embedding.weight
-        )  # [BHW, D]
-        quantized_latents = quantized_latents.view(latents_shape)  # [B x H x W x D]
+        Returns
+        -------
+        torch.Tensor
+            a channel-last tensor
+        """
+        return self.codebook(indices)
+
+    def forward(self, latents: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        quantized_latents = self.decode(self.encode(latents))
 
-        # Compute the VQ Losses
         commitment_loss = functional.mse_loss(quantized_latents.detach(), latents)
         embedding_loss = functional.mse_loss(quantized_latents, latents.detach())
 
         vq_loss = commitment_loss * self.beta + embedding_loss
 
-        # Add the residue back to the latents
         quantized_latents = latents + (quantized_latents - latents).detach()
-
-        return quantized_latents.permute(
-            0, 3, 1, 2
-        ).contiguous(), vq_loss  # [B x D x H x W]
+        return quantized_latents, vq_loss
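
After this change, forward no longer permutes to and from channel-first layout, so callers must pass channel-last latents and permute the result back themselves. Below is a minimal usage sketch, not part of the diff: the VectorQuantizer class name is assumed from the file path, and the constructor is assumed to set self.K = num_embeddings as the surrounding code implies.

import torch

# Assumed import path and class name, taken from the diffed file's location.
from encyclopedia_vae.modules.vectorquantizer import VectorQuantizer

vq = VectorQuantizer(num_embeddings=512, embedding_dim=64, beta=0.25)

# Encoder outputs are conventionally channel-first; permute before quantizing.
encoder_output = torch.randn(8, 64, 32, 32)    # [B x D x H x W]
latents = encoder_output.permute(0, 2, 3, 1)   # [B x H x W x D]

quantized, vq_loss = vq(latents)               # channel-last in, channel-last out
indices = vq.encode(latents)                   # [B x H x W] integer codebook indices
decoded = vq.decode(indices)                   # [B x H x W x D]

decoder_input = quantized.permute(0, 3, 1, 2)  # back to [B x D x H x W]

Moving the layout handling to the caller keeps encode and decode symmetric, and because encode only flattens the trailing dimension (latents.shape[:-1]), the same module can quantize latents of any spatial rank as long as the last dimension equals self.D.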