
Commit 2c7d49e
take care of the smoothing of the positions for autoencoder reconstruction loss
lucidrains committed Dec 5, 2023
1 parent ad250b7 commit 2c7d49e
Showing 2 changed files with 49 additions and 9 deletions.
56 changes: 48 additions & 8 deletions meshgpt_pytorch/meshgpt_pytorch.py
@@ -1,3 +1,5 @@
+from math import ceil

import torch
from torch import nn, Tensor, einsum
from torch.nn import Module, ModuleList
@@ -35,6 +37,9 @@ def default(v, d):
def divisible_by(num, den):
    return (num % den) == 0

+def l1norm(t):
+    return F.normalize(t, dim = -1, p = 1)

# tensor helper functions

@beartype
@@ -69,6 +74,27 @@ def undiscretize_coors(
    t /= num_discrete
    return t * (hi - lo) + lo

+@beartype
+def gaussian_blur_1d(
+    t: Tensor,
+    *,
+    sigma: float = 1.
+) -> Tensor:

+    _, channels, _, device = *t.shape, t.device

+    width = int(ceil(sigma * 5))
+    width += (width + 1) % 2
+    half_width = width // 2

+    distance = torch.arange(-half_width, half_width + 1, dtype = torch.float, device = device)

+    gaussian = torch.exp(-(distance ** 2) / (2 * sigma ** 2))
+    gaussian = l1norm(gaussian)

+    kernel = repeat(gaussian, 'n -> c 1 n', c = channels)
+    return F.conv1d(t, kernel, padding = half_width, groups = channels)
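
For intuition (an illustrative aside, not part of the diff): the kernel is L1-normalized, so blurring a one-hot bin vector spreads probability mass over neighboring bins while keeping the total at 1. A minimal sketch, assuming gaussian_blur_1d above is in scope and using made-up shapes:

import torch

one_hot = torch.zeros(1, 1, 9)  # (batch, channels, discretized coordinate bins)
one_hot[0, 0, 4] = 1.           # ground-truth bin at index 4

blurred = gaussian_blur_1d(one_hot, sigma = 1.)

print(blurred[0, 0])  # roughly [0, 0, 0.054, 0.244, 0.403, 0.244, 0.054, 0, 0]
print(blurred.sum())  # tensor(1.) - total mass preserved by the l1norm'd kernel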

# resnet block

class Block(Module):
@@ -113,10 +139,11 @@ def __init__(
        encoder_depth = 2,
        decoder_depth = 2,
        dim_codebook = 192,
-        num_quantizers = 2, # or 'D' in the paper
-        codebook_size = 16384, # they use 16k, shared codebook between layers
+        num_quantizers = 2,           # or 'D' in the paper
+        codebook_size = 16384,        # they use 16k, shared codebook between layers
        rq_kwargs: dict = dict(),
        commit_loss_weight = 0.1,
+        bin_smooth_blur_sigma = 0.4,  # they blur the one hot discretized coordinate positions
    ):
        super().__init__()

@@ -147,8 +174,6 @@ def __init__(
            **rq_kwargs
        )

-        self.commit_loss_weight = commit_loss_weight

        self.project_codebook_out = nn.Linear(dim_codebook * 3, dim)

        self.decoders = ModuleList([])
@@ -162,6 +187,11 @@ def __init__(
            Rearrange('... (v c) -> ... v c', v = 9)
        )

+        # loss related

+        self.commit_loss_weight = commit_loss_weight
+        self.bin_smooth_blur_sigma = bin_smooth_blur_sigma

    @beartype
    def encode(
        self,
@@ -354,13 +384,23 @@ def forward(
        decode = self.decode(quantized)

        pred_face_coords = self.to_coor_logits(decode)
+        pred_face_coords = rearrange(pred_face_coords, 'b ... c -> b c (...)')

        # reconstruction loss on discretized coordinates on each face
+        # they also smooth (blur) the one hot positions, localized label smoothing basically

-        recon_loss = F.cross_entropy(
-            rearrange(pred_face_coords, 'b ... c -> b c ...'),
-            face_coordinates
-        )
+        face_coordinates = rearrange(face_coordinates, 'b ... -> b 1 (...)')

+        pred_log_prob = pred_face_coords.log_softmax(dim = 1)

+        target_one_hot = torch.zeros_like(pred_log_prob).scatter(1, face_coordinates, 1.)

+        if self.bin_smooth_blur_sigma >= 0.:
+            target_one_hot = gaussian_blur_1d(target_one_hot, sigma = self.bin_smooth_blur_sigma)

+        # cross entropy with localized smoothing

+        recon_loss = (-target_one_hot * pred_log_prob).sum(dim = 1).mean()

        # calculate total loss

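As a sanity check on the new loss (a sketch, not from the commit): with the blur disabled, i.e. a pure one-hot target, the hand-rolled (-target_one_hot * pred_log_prob).sum(dim = 1).mean() reduces to exactly the F.cross_entropy call it replaces; the blur only redistributes target mass to nearby bins. With made-up shapes:

import torch
import torch.nn.functional as F

logits  = torch.randn(2, 128, 36)         # (batch, num_discrete bins, 9 coords * num faces)
targets = torch.randint(0, 128, (2, 36))  # discretized face coordinates

ce = F.cross_entropy(logits, targets)     # what the old code computed

log_prob = logits.log_softmax(dim = 1)
one_hot  = torch.zeros_like(log_prob).scatter(1, targets.unsqueeze(1), 1.)
manual   = (-one_hot * log_prob).sum(dim = 1).mean()

assert torch.allclose(ce, manual, atol = 1e-6)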
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
  name = 'meshgpt-pytorch',
  packages = find_packages(exclude=[]),
-  version = '0.0.3',
+  version = '0.0.4',
  license='MIT',
  description = 'MeshGPT Pytorch',
  author = 'Phil Wang',
